diff --git a/src/Framework/App.Ref/src/Microsoft.AspNetCore.App.Ref.sfxproj b/src/Framework/App.Ref/src/Microsoft.AspNetCore.App.Ref.sfxproj index f0a83c2e417b..bd685a879b2d 100644 --- a/src/Framework/App.Ref/src/Microsoft.AspNetCore.App.Ref.sfxproj +++ b/src/Framework/App.Ref/src/Microsoft.AspNetCore.App.Ref.sfxproj @@ -76,6 +76,11 @@ Private="false" OutputItemType="AspNetCoreAnalyzer" ReferenceOutputAssembly="false" /> + + diff --git a/src/Http/Http.Abstractions/src/Metadata/IDisableValidationMetadata.cs b/src/Http/Http.Abstractions/src/Metadata/IDisableValidationMetadata.cs new file mode 100644 index 000000000000..f45eff5a09fe --- /dev/null +++ b/src/Http/Http.Abstractions/src/Metadata/IDisableValidationMetadata.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.Metadata; + +/// +/// A marker interface which can be used to identify metadata that disables validation +/// on a given endpoint. +/// +public interface IDisableValidationMetadata +{ +} diff --git a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt index d7c55b7606ff..1af37e0a52a1 100644 --- a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt @@ -1,4 +1,45 @@ #nullable enable +abstract Microsoft.AspNetCore.Http.Validation.ValidatableParameterInfo.GetValidationAttributes() -> System.ComponentModel.DataAnnotations.ValidationAttribute![]! +abstract Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo.GetValidationAttributes() -> System.ComponentModel.DataAnnotations.ValidationAttribute![]! +Microsoft.AspNetCore.Http.Metadata.IDisableValidationMetadata Microsoft.AspNetCore.Http.ProducesResponseTypeMetadata.Description.get -> string? Microsoft.AspNetCore.Http.ProducesResponseTypeMetadata.Description.set -> void Microsoft.AspNetCore.Http.Metadata.IProducesResponseTypeMetadata.Description.get -> string? +Microsoft.AspNetCore.Http.Validation.IValidatableInfo +Microsoft.AspNetCore.Http.Validation.IValidatableInfo.ValidateAsync(object? value, Microsoft.AspNetCore.Http.Validation.ValidatableContext! context, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.ValueTask +Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver +Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver.TryGetValidatableParameterInfo(System.Reflection.ParameterInfo! parameterInfo, out Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) -> bool +Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver.TryGetValidatableTypeInfo(System.Type! type, out Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) -> bool +Microsoft.AspNetCore.Http.Validation.ValidatableContext +Microsoft.AspNetCore.Http.Validation.ValidatableContext.CurrentDepth.get -> int +Microsoft.AspNetCore.Http.Validation.ValidatableContext.CurrentDepth.set -> void +Microsoft.AspNetCore.Http.Validation.ValidatableContext.Prefix.get -> string! +Microsoft.AspNetCore.Http.Validation.ValidatableContext.Prefix.set -> void +Microsoft.AspNetCore.Http.Validation.ValidatableContext.ValidatableContext() -> void +Microsoft.AspNetCore.Http.Validation.ValidatableContext.ValidationContext.get -> System.ComponentModel.DataAnnotations.ValidationContext? +Microsoft.AspNetCore.Http.Validation.ValidatableContext.ValidationContext.set -> void +Microsoft.AspNetCore.Http.Validation.ValidatableContext.ValidationErrors.get -> System.Collections.Generic.Dictionary? +Microsoft.AspNetCore.Http.Validation.ValidatableContext.ValidationErrors.set -> void +Microsoft.AspNetCore.Http.Validation.ValidatableContext.ValidationOptions.get -> Microsoft.AspNetCore.Http.Validation.ValidationOptions! +Microsoft.AspNetCore.Http.Validation.ValidatableContext.ValidationOptions.set -> void +Microsoft.AspNetCore.Http.Validation.ValidatableParameterInfo +Microsoft.AspNetCore.Http.Validation.ValidatableParameterInfo.ValidatableParameterInfo(System.Type! parameterType, string! name, string! displayName) -> void +Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo +Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo.IsRequired.get -> bool +Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo.ValidatablePropertyInfo(System.Type! declaringType, System.Type! propertyType, string! name, string! displayName) -> void +Microsoft.AspNetCore.Http.Validation.ValidatableTypeAttribute +Microsoft.AspNetCore.Http.Validation.ValidatableTypeAttribute.ValidatableTypeAttribute() -> void +Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo +Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo.ValidatableTypeInfo(System.Type! type, System.Collections.Generic.IReadOnlyList! members) -> void +Microsoft.AspNetCore.Http.Validation.ValidationOptions +Microsoft.AspNetCore.Http.Validation.ValidationOptions.MaxDepth.get -> int +Microsoft.AspNetCore.Http.Validation.ValidationOptions.MaxDepth.set -> void +Microsoft.AspNetCore.Http.Validation.ValidationOptions.Resolvers.get -> System.Collections.Generic.IList! +Microsoft.AspNetCore.Http.Validation.ValidationOptions.TryGetValidatableParameterInfo(System.Reflection.ParameterInfo! parameterInfo, out Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableInfo) -> bool +Microsoft.AspNetCore.Http.Validation.ValidationOptions.TryGetValidatableTypeInfo(System.Type! type, out Microsoft.AspNetCore.Http.Validation.IValidatableInfo? validatableTypeInfo) -> bool +Microsoft.AspNetCore.Http.Validation.ValidationOptions.ValidationOptions() -> void +Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions +static Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services, System.Action? configureOptions = null) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! +virtual Microsoft.AspNetCore.Http.Validation.ValidatableParameterInfo.ValidateAsync(object? value, Microsoft.AspNetCore.Http.Validation.ValidatableContext! context, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.ValueTask +virtual Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo.ValidateAsync(object? value, Microsoft.AspNetCore.Http.Validation.ValidatableContext! context, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.ValueTask +virtual Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo.ValidateAsync(object? value, Microsoft.AspNetCore.Http.Validation.ValidatableContext! context, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.ValueTask diff --git a/src/Http/Http.Abstractions/src/Validation/IValidatableInfo.cs b/src/Http/Http.Abstractions/src/Validation/IValidatableInfo.cs new file mode 100644 index 000000000000..43242203d780 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/IValidatableInfo.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Represents an interface for validating a value. +/// +public interface IValidatableInfo +{ + /// + /// Validates the specified value. + /// + /// The value to validate. + /// The validation context. + /// + ValueTask ValidateAsync(object? value, ValidatableContext context, CancellationToken cancellationToken); +} diff --git a/src/Http/Http.Abstractions/src/Validation/IValidatableInfoResolver.cs b/src/Http/Http.Abstractions/src/Validation/IValidatableInfoResolver.cs new file mode 100644 index 000000000000..b4d4abe31c2d --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/IValidatableInfoResolver.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Provides an interface for resolving the validation information associated +/// with a given or . +/// +public interface IValidatableInfoResolver +{ + /// + /// Gets validation information for the specified type. + /// + /// The type to get validation information for. + /// + /// The output parameter that will contain the validatable information if found. + /// + /// if the validatable type information was found; otherwise, false. + bool TryGetValidatableTypeInfo(Type type, [NotNullWhen(true)] out IValidatableInfo? validatableInfo); + + /// + /// Gets validation information for the specified parameter. + /// + /// The parameter to get validation information for. + /// The output parameter that will contain the validatable information if found. + /// if the validatable parameter information was found; otherwise, false. + bool TryGetValidatableParameterInfo(ParameterInfo parameterInfo, [NotNullWhen(true)] out IValidatableInfo? validatableInfo); +} diff --git a/src/Http/Http.Abstractions/src/Validation/RuntimeValidatableParameterInfoResolver.cs b/src/Http/Http.Abstractions/src/Validation/RuntimeValidatableParameterInfoResolver.cs new file mode 100644 index 000000000000..1d901b192ea8 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/RuntimeValidatableParameterInfoResolver.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Reflection; + +namespace Microsoft.AspNetCore.Http.Validation; + +internal class RuntimeValidatableParameterInfoResolver : IValidatableInfoResolver +{ + public bool TryGetValidatableTypeInfo(Type type, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + + public bool TryGetValidatableParameterInfo(ParameterInfo parameterInfo, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + Debug.Assert(parameterInfo.Name != null, "Parameter must have name"); + var validationAttributes = parameterInfo + .GetCustomAttributes() + .ToArray(); + validatableInfo = new RuntimeValidatableParameterInfo( + parameterType: parameterInfo.ParameterType, + name: parameterInfo.Name, + displayName: GetDisplayName(parameterInfo), + validationAttributes: validationAttributes + ); + return true; + } + + private static string GetDisplayName(ParameterInfo parameterInfo) + { + var displayAttribute = parameterInfo.GetCustomAttribute(); + if (displayAttribute != null) + { + return displayAttribute.Name ?? parameterInfo.Name!; + } + + return parameterInfo.Name!; + } + + private class RuntimeValidatableParameterInfo( + Type parameterType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) : + ValidatableParameterInfo(parameterType, name, displayName) + { + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + + private readonly ValidationAttribute[] _validationAttributes = validationAttributes; + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/TypeExtensions.cs b/src/Http/Http.Abstractions/src/Validation/TypeExtensions.cs new file mode 100644 index 000000000000..980d1f1ff694 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/TypeExtensions.cs @@ -0,0 +1,113 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.Validation; + +internal static class TypeExtensions +{ + public static bool IsEnumerable(this Type type) + { + // Check if type itself is an IEnumerable + if (type.IsGenericType && + (type.GetGenericTypeDefinition() == typeof(IEnumerable<>) || + type.GetGenericTypeDefinition() == typeof(ICollection<>) || + type.GetGenericTypeDefinition() == typeof(List<>))) + { + return true; + } + + // Or an array + if (type.IsArray) + { + return true; + } + + // Then evaluate if it implements IEnumerable and is not a string + if (typeof(IEnumerable).IsAssignableFrom(type) && + type != typeof(string)) + { + return true; + } + + return false; + } + + public static bool IsNullable(this Type type) + { + if (type.IsValueType) + { + return false; + } + + if (type.IsGenericType && + type.GetGenericTypeDefinition() == typeof(Nullable<>)) + { + return true; + } + + return false; + } + + public static bool TryGetRequiredAttribute(this ValidationAttribute[] attributes, [NotNullWhen(true)] out RequiredAttribute? requiredAttribute) + { + foreach (var attribute in attributes) + { + if (attribute is RequiredAttribute requiredAttr) + { + requiredAttribute = requiredAttr; + return true; + } + } + + requiredAttribute = null; + return false; + } + + /// + /// Gets all types that the specified type implements or inherits from, including itself. + /// + /// The type to analyze. + /// A collection containing the type itself, all implemented interfaces, and all base types. + public static IEnumerable GetAllImplementedTypes([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] this Type type) + { + ArgumentNullException.ThrowIfNull(type); + + // Yield all interfaces directly and indirectly implemented by this type + foreach (var interfaceType in type.GetInterfaces()) + { + yield return interfaceType; + } + + // Finally, walk up the inheritance chain + var baseType = type.BaseType; + while (baseType != null && baseType != typeof(object)) + { + yield return baseType; + baseType = baseType.BaseType; + } + } + + /// + /// Determines whether the specified type implements the given interface. + /// + /// The type to check. + /// The interface type to check for. + /// True if the type implements the specified interface; otherwise, false. + public static bool ImplementsInterface(this Type type, Type interfaceType) + { + ArgumentNullException.ThrowIfNull(type); + ArgumentNullException.ThrowIfNull(interfaceType); + + // Check if interfaceType is actually an interface + if (!interfaceType.IsInterface) + { + throw new ArgumentException($"Type {interfaceType.FullName} is not an interface.", nameof(interfaceType)); + } + + return interfaceType.IsAssignableFrom(type); + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidatableContext.cs b/src/Http/Http.Abstractions/src/Validation/ValidatableContext.cs new file mode 100644 index 000000000000..9aa2f7e7f9d5 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidatableContext.cs @@ -0,0 +1,80 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Represents the context for validating a validatable object. +/// +public sealed class ValidatableContext +{ + /// + /// Gets or sets the validation context used for validating objects that implement or have . + /// This context provides access to service provider and other validation metadata. + /// + public ValidationContext? ValidationContext { get; set; } + + /// + /// Gets or sets the prefix used to identify the current object being validated in a complex object graph. + /// This is used to build property paths in validation error messages (e.g., "Customer.Address.Street"). + /// + public string Prefix { get; set; } = string.Empty; + + /// + /// Gets or sets the validation options that control validation behavior, + /// including validation depth limits and resolver registration. + /// + public required ValidationOptions ValidationOptions { get; set; } + + /// + /// Gets or sets the dictionary of validation errors collected during validation. + /// Keys are property names or paths, and values are arrays of error messages. + /// This dictionary is lazily initialized when the first validation error is added. + /// + public Dictionary? ValidationErrors { get; set; } + + /// + /// Gets or sets the current depth in the validation hierarchy. + /// This is used to prevent stack overflows from circular references. + /// + public int CurrentDepth { get; set; } + + internal void AddValidationError(string key, string[] error) + { + ValidationErrors ??= []; + + ValidationErrors[key] = error; + } + + internal void AddOrExtendValidationErrors(string key, string[] errors) + { + ValidationErrors ??= []; + + if (ValidationErrors.TryGetValue(key, out var existingErrors)) + { + ValidationErrors[key] = new string[existingErrors.Length + errors.Length]; + existingErrors.CopyTo(ValidationErrors[key], 0); + errors.CopyTo(ValidationErrors[key], existingErrors.Length); + } + else + { + ValidationErrors[key] = errors; + } + } + + internal void AddOrExtendValidationError(string key, string error) + { + ValidationErrors ??= []; + + if (ValidationErrors.TryGetValue(key, out var existingErrors) && !existingErrors.Contains(error)) + { + ValidationErrors[key] = [.. existingErrors, error]; + } + else + { + ValidationErrors[key] = [error]; + } + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidatableParameterInfo.cs b/src/Http/Http.Abstractions/src/Validation/ValidatableParameterInfo.cs new file mode 100644 index 000000000000..5925f49998b6 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidatableParameterInfo.cs @@ -0,0 +1,142 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.ComponentModel.DataAnnotations; +using System.Diagnostics; + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Contains validation information for a parameter. +/// +public abstract class ValidatableParameterInfo : IValidatableInfo +{ + private RequiredAttribute? _requiredAttribute; + + /// + /// Creates a new instance of . + /// + /// The associated with the parameter. + /// The parameter name. + /// The display name for the parameter. + public ValidatableParameterInfo( + Type parameterType, + string name, + string displayName) + { + ParameterType = parameterType; + Name = name; + DisplayName = displayName; + } + + /// + /// Gets the parameter type. + /// + internal Type ParameterType { get; } + + /// + /// Gets the parameter name. + /// + internal string Name { get; } + + /// + /// Gets the display name for the parameter. + /// + internal string DisplayName { get; } + + /// + /// Gets the validation attributes for this parameter. + /// + /// An array of validation attributes to apply to this parameter. + protected abstract ValidationAttribute[] GetValidationAttributes(); + + /// + /// Validates the parameter value. + /// + /// The value to validate. + /// The context for the validation. + /// + /// A task representing the asynchronous operation. + /// + /// If the parameter is a collection, each item in the collection will be validated. + /// If the parameter is not a collection but has a validatable type, the single value will be validated. + /// + public virtual async ValueTask ValidateAsync(object? value, ValidatableContext context, CancellationToken cancellationToken) + { + Debug.Assert(context.ValidationContext is not null); + + // Skip validation if value is null and parameter is optional + if (value == null && ParameterType.IsNullable()) + { + return; + } + + context.ValidationContext.DisplayName = DisplayName; + context.ValidationContext.MemberName = Name; + + var validationAttributes = GetValidationAttributes(); + + if (_requiredAttribute is not null && validationAttributes.TryGetRequiredAttribute(out _requiredAttribute)) + { + var result = _requiredAttribute.GetValidationResult(value, context.ValidationContext); + + if (result is not null && result != ValidationResult.Success) + { + var key = string.IsNullOrEmpty(context.Prefix) ? Name : $"{context.Prefix}.{Name}"; + context.AddValidationError(key, [result.ErrorMessage!]); + return; + } + } + + // Validate against validation attributes + for (var i = 0; i < validationAttributes.Length; i++) + { + var attribute = validationAttributes[i]; + try + { + var result = attribute.GetValidationResult(value, context.ValidationContext); + if (result is not null && result != ValidationResult.Success) + { + var key = string.IsNullOrEmpty(context.Prefix) ? Name : $"{context.Prefix}.{Name}"; + context.AddOrExtendValidationErrors(key, [result!.ErrorMessage!]); + } + } + catch (Exception ex) + { + var key = string.IsNullOrEmpty(context.Prefix) ? Name : $"{context.Prefix}.{Name}"; + context.AddValidationError(key, [ex.Message]); + } + } + + // If the parameter is a collection, validate each item + if (ParameterType.IsEnumerable() && value is IEnumerable enumerable) + { + var index = 0; + foreach (var item in enumerable) + { + if (item != null) + { + var itemPrefix = string.IsNullOrEmpty(context.Prefix) + ? $"{Name}[{index}]" + : $"{context.Prefix}.{Name}[{index}]"; + + if (context.ValidationOptions.TryGetValidatableTypeInfo(item.GetType(), out var validatableType)) + { + await validatableType.ValidateAsync(item, context, cancellationToken); + } + } + index++; + } + } + // If not enumerable, validate the single value + else if (value != null) + { + var valueType = value.GetType(); + if (context.ValidationOptions.TryGetValidatableTypeInfo(valueType, out var validatableType)) + { + await validatableType.ValidateAsync(value, context, cancellationToken); + } + } + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidatablePropertyInfo.cs b/src/Http/Http.Abstractions/src/Validation/ValidatablePropertyInfo.cs new file mode 100644 index 000000000000..0f9e7702e4f9 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidatablePropertyInfo.cs @@ -0,0 +1,194 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Contains validation information for a member of a type. +/// +public abstract class ValidatablePropertyInfo : IValidatableInfo +{ + /// + /// Creates a new instance of . + /// + public ValidatablePropertyInfo( + [param: DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] + Type declaringType, + Type propertyType, + string name, + string displayName) + { + DeclaringType = declaringType; + PropertyType = propertyType; + Name = name; + DisplayName = displayName; + } + + /// + /// Gets the member type. + /// + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicProperties)] + internal Type DeclaringType { get; } + + /// + /// Gets the member type. + /// + internal Type PropertyType { get; } + + /// + /// Gets the member name. + /// + internal string Name { get; } + + /// + /// Gets the display name for the member as designated by the . + /// + internal string DisplayName { get; } + + /// + /// Gets whether the member is enumerable. + /// + internal bool IsEnumerable { get; } + + /// + /// Gets whether the member is nullable. + /// + internal bool IsNullable { get; } + + /// + /// Gets whether the member is annotated with the . + /// + public bool IsRequired { get; } + + /// + /// Gets the validation attributes for this member. + /// + /// An array of validation attributes to apply to this member. + protected abstract ValidationAttribute[] GetValidationAttributes(); + + /// + /// Validates the member's value. + /// + /// The object containing the member to validate. + /// The context for the validation. + /// + /// A task representing the asynchronous operation. + public virtual async ValueTask ValidateAsync(object? value, ValidatableContext context, CancellationToken cancellationToken) + { + Debug.Assert(context.ValidationContext is not null); + + var property = DeclaringType.GetProperty(Name)!; + var propertyValue = property.GetValue(value); + var validationAttributes = GetValidationAttributes(); + + // Calculate and save the current path + var originalPrefix = context.Prefix; + if (string.IsNullOrEmpty(originalPrefix)) + { + context.Prefix = Name; + } + else + { + context.Prefix = $"{originalPrefix}.{Name}"; + } + + context.ValidationContext.DisplayName = DisplayName; + context.ValidationContext.MemberName = Name; + + // Check required attribute first + if (validationAttributes.TryGetRequiredAttribute(out var requiredAttribute)) + { + var result = requiredAttribute.GetValidationResult(propertyValue, context.ValidationContext); + + if (result is not null && result != ValidationResult.Success) + { + context.AddValidationError(context.Prefix, [result!.ErrorMessage!]); + context.Prefix = originalPrefix; // Restore prefix + return; + } + } + + // Validate any other attributes + ValidateValue(propertyValue, context.Prefix, validationAttributes); + + // Check if we've reached the maximum depth before validating complex properties + if (context.CurrentDepth >= context.ValidationOptions.MaxDepth) + { + throw new InvalidOperationException( + $"Maximum validation depth of {context.ValidationOptions.MaxDepth} exceeded at '{context.Prefix}'. " + + "This is likely caused by a circular reference in the object graph. " + + "Consider increasing the MaxDepth in ValidationOptions if deeper validation is required."); + } + + // Increment depth counter + context.CurrentDepth++; + + try + { + // Handle enumerable values + if (PropertyType.IsEnumerable() && propertyValue is System.Collections.IEnumerable enumerable) + { + var index = 0; + var currentPrefix = context.Prefix; + + foreach (var item in enumerable) + { + context.Prefix = $"{currentPrefix}[{index}]"; + + if (item != null) + { + var itemType = item.GetType(); + if (context.ValidationOptions.TryGetValidatableTypeInfo(itemType, out var validatableType)) + { + await validatableType.ValidateAsync(item, context, cancellationToken); + } + } + + index++; + } + + // Restore prefix to the property name before validating the next item + context.Prefix = currentPrefix; + } + else if (propertyValue != null) + { + // Validate as a complex object + var valueType = propertyValue.GetType(); + if (context.ValidationOptions.TryGetValidatableTypeInfo(valueType, out var validatableType)) + { + await validatableType.ValidateAsync(propertyValue, context, cancellationToken); + } + } + } + finally + { + // Always decrement the depth counter and restore prefix + context.CurrentDepth--; + context.Prefix = originalPrefix; + } + + void ValidateValue(object? val, string errorPrefix, ValidationAttribute[] validationAttributes) + { + for (var i = 0; i < validationAttributes.Length; i++) + { + var attribute = validationAttributes[i]; + try + { + var result = attribute.GetValidationResult(val, context.ValidationContext); + if (result is not null && result != ValidationResult.Success) + { + context.AddOrExtendValidationErrors(errorPrefix.TrimStart('.'), [result.ErrorMessage!]); + } + } + catch (Exception ex) + { + context.AddOrExtendValidationErrors(errorPrefix.TrimStart('.'), [ex.Message]); + } + } + } + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidatableTypeAttribute.cs b/src/Http/Http.Abstractions/src/Validation/ValidatableTypeAttribute.cs new file mode 100644 index 000000000000..0ea382c59a55 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidatableTypeAttribute.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Indicates that a type is validatable to support discovery by the +/// validations generator. +/// +[AttributeUsage(AttributeTargets.Class)] +public sealed class ValidatableTypeAttribute : Attribute +{ +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidatableTypeInfo.cs b/src/Http/Http.Abstractions/src/Validation/ValidatableTypeInfo.cs new file mode 100644 index 000000000000..8eda099801c8 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidatableTypeInfo.cs @@ -0,0 +1,134 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Contains validation information for a type. +/// +public abstract class ValidatableTypeInfo : IValidatableInfo +{ + private readonly int _membersCount; + private readonly IEnumerable _subTypes; + + /// + /// Creates a new instance of . + /// + /// The type being validated. + /// The members that can be validated. + public ValidatableTypeInfo( + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.Interfaces)] Type type, + IReadOnlyList members) + { + Type = type; + Members = members; + _membersCount = members.Count; + _subTypes = type.GetAllImplementedTypes(); + } + + /// + /// The type being validated. + /// + internal Type Type { get; } + + /// + /// The members that can be validated. + /// + internal IReadOnlyList Members { get; } + + /// + /// Validates the specified value. + /// + /// The value to validate. + /// The validation context. + /// + public virtual async ValueTask ValidateAsync(object? value, ValidatableContext context, CancellationToken cancellationToken) + { + Debug.Assert(context.ValidationContext is not null); + if (value == null) + { + return; + } + + // Check if we've exceeded the maximum depth + if (context.CurrentDepth >= context.ValidationOptions.MaxDepth) + { + throw new InvalidOperationException( + $"Maximum validation depth of {context.ValidationOptions.MaxDepth} exceeded at '{context.Prefix}'. " + + "This is likely caused by a circular reference in the object graph. " + + "Consider increasing the MaxDepth in ValidationOptions if deeper validation is required."); + } + + try + { + var actualType = value.GetType(); + var originalPrefix = context.Prefix; + + // First validate members + for (var i = 0; i < _membersCount; i++) + { + await Members[i].ValidateAsync(value, context, cancellationToken); + context.Prefix = originalPrefix; + } + + // Then validate sub-types if any + foreach (var subType in _subTypes) + { + // Check if the actual type is assignable to the sub-type + // and validate it if it is + if (subType.IsAssignableFrom(actualType)) + { + if (context.ValidationOptions.TryGetValidatableTypeInfo(subType, out var subTypeInfo)) + { + await subTypeInfo.ValidateAsync(value, context, cancellationToken); + context.Prefix = originalPrefix; + } + } + } + + // Finally validate IValidatableObject if implemented + if (Type.ImplementsInterface(typeof(IValidatableObject)) && value is IValidatableObject validatable) + { + // Important: Set the DisplayName to the type name for top-level validations + // and restore the original validation context properties + var originalDisplayName = context.ValidationContext.DisplayName; + var originalMemberName = context.ValidationContext.MemberName; + + // Set the display name to the class name for IValidatableObject validation + context.ValidationContext.DisplayName = Type.Name; + context.ValidationContext.MemberName = null; + + var validationResults = validatable.Validate(context.ValidationContext); + foreach (var validationResult in validationResults) + { + if (validationResult != ValidationResult.Success) + { + var memberName = validationResult.MemberNames.First(); + var key = string.IsNullOrEmpty(originalPrefix) ? + memberName : + $"{originalPrefix}.{memberName}"; + + context.AddOrExtendValidationError(key, validationResult.ErrorMessage!); + } + } + + // Restore the original validation context properties + context.ValidationContext.DisplayName = originalDisplayName; + context.ValidationContext.MemberName = originalMemberName; + } + + // Always restore original prefix + context.Prefix = originalPrefix; + } + finally + { + // Decrement depth when validation completes + context.CurrentDepth--; + } + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidationOptions.cs b/src/Http/Http.Abstractions/src/Validation/ValidationOptions.cs new file mode 100644 index 000000000000..d27ba37eaf13 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidationOptions.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace Microsoft.AspNetCore.Http.Validation; + +/// +/// Provides configuration options for the validation system. +/// +public class ValidationOptions +{ + /// + /// Gets the list of resolvers that provide validation metadata for types and parameters. + /// Resolvers are processed in order, with the first resolver providing a non-null result being used. + /// + /// + /// Source-generated resolvers are typically inserted at the beginning of this list + /// to ensure they are checked before any runtime-based resolvers. + /// + public IList Resolvers { get; } = []; + + /// + /// Gets or sets the maximum depth for validation of nested objects. + /// This prevents stack overflows from circular references or extremely deep object graphs. + /// Default value is 32. + /// + public int MaxDepth { get; set; } = 32; + + /// + /// Attempts to get validation information for the specified type. + /// + /// The type to get validation information for. + /// When this method returns, contains the validation information for the specified type, + /// if the type was found; otherwise, null. + /// true if validation information was found for the specified type; otherwise, false. + public bool TryGetValidatableTypeInfo(Type type, [NotNullWhen(true)] out IValidatableInfo? validatableTypeInfo) + { + foreach (var resolver in Resolvers) + { + if (resolver.TryGetValidatableTypeInfo(type, out validatableTypeInfo)) + { + return true; + } + } + + validatableTypeInfo = null; + return false; + } + + /// + /// Attempts to get validation information for the specified parameter. + /// + /// The parameter to get validation information for. + /// When this method returns, contains the validation information for the specified parameter, + /// if validation information was found; otherwise, null. + /// true if validation information was found for the specified parameter; otherwise, false. + public bool TryGetValidatableParameterInfo(ParameterInfo parameterInfo, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + foreach (var resolver in Resolvers) + { + if (resolver.TryGetValidatableParameterInfo(parameterInfo, out validatableInfo)) + { + return true; + } + } + + validatableInfo = null; + return false; + } +} diff --git a/src/Http/Http.Abstractions/src/Validation/ValidationServiceCollectionExtensions.cs b/src/Http/Http.Abstractions/src/Validation/ValidationServiceCollectionExtensions.cs new file mode 100644 index 000000000000..77a128842ea4 --- /dev/null +++ b/src/Http/Http.Abstractions/src/Validation/ValidationServiceCollectionExtensions.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Http.Validation; + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Extension methods for adding validation services. +/// +public static class ValidationServiceCollectionExtensions +{ + /// + /// Adds the validation services to the specified . + /// + /// The to add the services to. + /// An optional action to configure the . + /// The for chaining. + public static IServiceCollection AddValidation(this IServiceCollection services, Action? configureOptions = null) + { + services.Configure(options => + { + if (configureOptions is not null) + { + configureOptions(options); + } + // Support ParameterInfo resolution at runtime + options.Resolvers.Add(new RuntimeValidatableParameterInfoResolver()); + }); + return services; + } +} diff --git a/src/Http/Http.Abstractions/test/Validation/ValidatableInfoResolverTests.cs b/src/Http/Http.Abstractions/test/Validation/ValidatableInfoResolverTests.cs new file mode 100644 index 000000000000..6960162a13c5 --- /dev/null +++ b/src/Http/Http.Abstractions/test/Validation/ValidatableInfoResolverTests.cs @@ -0,0 +1,221 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Reflection; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Moq; + +namespace Microsoft.AspNetCore.Http.Validation.Tests; + +public class ValidatableInfoResolverTests +{ + public delegate void TryGetValidatableTypeInfoCallback(Type type, out IValidatableInfo? validatableInfo); + public delegate void TryGetValidatableParameterInfoCallback(ParameterInfo parameter, out IValidatableInfo? validatableInfo); + + [Fact] + public void GetValidatableTypeInfo_ReturnsNull_ForNonValidatableType() + { + // Arrange + var resolver = new Mock(); + IValidatableInfo? validatableInfo = null; + resolver.Setup(r => r.TryGetValidatableTypeInfo(It.IsAny(), out validatableInfo)).Returns(false); + + // Act + var result = resolver.Object.TryGetValidatableTypeInfo(typeof(NonValidatableType), out validatableInfo); + + // Assert + Assert.False(result); + Assert.Null(validatableInfo); + } + + [Fact] + public void GetValidatableTypeInfo_ReturnsTypeInfo_ForValidatableType() + { + // Arrange + var mockTypeInfo = new Mock( + typeof(ValidatableType), + Array.Empty()).Object; + + var resolver = new Mock(); + IValidatableInfo? validatableInfo = null; + resolver + .Setup(r => r.TryGetValidatableTypeInfo(typeof(ValidatableType), out validatableInfo)) + .Callback(new TryGetValidatableTypeInfoCallback((t, out info) => + { + info = mockTypeInfo; // Set the out parameter to our mock + })) + .Returns(true); + + // Act + var result = resolver.Object.TryGetValidatableTypeInfo(typeof(ValidatableType), out validatableInfo); + + // Assert + Assert.True(result); + var validatableTypeInfo = Assert.IsAssignableFrom(validatableInfo); + Assert.Equal(typeof(ValidatableType), validatableTypeInfo.Type); + } + + [Fact] + public void GetValidatableParameterInfo_ReturnsNull_ForNonValidatableParameter() + { + // Arrange + var method = typeof(TestMethods).GetMethod(nameof(TestMethods.MethodWithNonValidatableParam))!; + var parameter = method.GetParameters()[0]; + + var resolver = new Mock(); + IValidatableInfo? validatableInfo = null; + resolver.Setup(r => r.TryGetValidatableParameterInfo(It.IsAny(), out validatableInfo)).Returns(false); + + // Act + var result = resolver.Object.TryGetValidatableParameterInfo(parameter, out validatableInfo); + + // Assert + Assert.False(result); + } + + [Fact] + public void GetValidatableParameterInfo_ReturnsParameterInfo_ForValidatableParameter() + { + // Arrange + var method = typeof(TestMethods).GetMethod(nameof(TestMethods.MethodWithValidatableParam))!; + var parameter = method.GetParameters()[0]; + + var mockParamInfo = new Mock( + typeof(string), + "model", + "model").Object; + + var resolver = new Mock(); + + // Setup using the same pattern as in the type info test + resolver.Setup(r => r.TryGetValidatableParameterInfo(parameter, out It.Ref.IsAny)) + .Callback(new TryGetValidatableParameterInfoCallback((ParameterInfo p, out IValidatableInfo? info) => + { + info = mockParamInfo; // Set the out parameter to our mock + })) + .Returns(true); + + // Act + var result = resolver.Object.TryGetValidatableParameterInfo(parameter, out var validatableInfo); + + // Assert + Assert.True(result); + var validatableParamInfo = Assert.IsAssignableFrom(validatableInfo); + Assert.Equal("model", validatableParamInfo.Name); + } + + [Fact] + public void ResolversChain_ProcessesInCorrectOrder() + { + // Arrange + var services = new ServiceCollection(); + + var resolver1 = new Mock(); + var resolver2 = new Mock(); + var resolver3 = new Mock(); + + // Create the object that will be returned by resolver2 + var mockTypeInfo = new Mock(typeof(ValidatableType), Array.Empty()).Object; + + // Setup resolver1 to return false (doesn't handle this type) + resolver1 + .Setup(r => r.TryGetValidatableTypeInfo(typeof(ValidatableType), out It.Ref.IsAny)) + .Callback(new TryGetValidatableTypeInfoCallback((Type t, out IValidatableInfo? info) => + { + info = null; + })) + .Returns(false); + + // Setup resolver2 to return true and set the mock type info + resolver2 + .Setup(r => r.TryGetValidatableTypeInfo(typeof(ValidatableType), out It.Ref.IsAny)) + .Callback(new TryGetValidatableTypeInfoCallback((Type t, out IValidatableInfo? info) => + { + info = mockTypeInfo; + })) + .Returns(true); + + services.AddValidation(Options => + { + Options.Resolvers.Add(resolver1.Object); + Options.Resolvers.Add(resolver2.Object); + Options.Resolvers.Add(resolver3.Object); + }); + + var serviceProvider = services.BuildServiceProvider(); + var validationOptions = serviceProvider.GetRequiredService>().Value; + + // Act + var result = validationOptions.TryGetValidatableTypeInfo(typeof(ValidatableType), out var validatableInfo); + + // Assert + Assert.True(result); + Assert.NotNull(validatableInfo); + Assert.Equal(typeof(ValidatableType), ((ValidatableTypeInfo)validatableInfo).Type); + + // Verify that resolvers were called in the expected order + resolver1.Verify(r => r.TryGetValidatableTypeInfo(typeof(ValidatableType), out It.Ref.IsAny), Times.Once); + resolver2.Verify(r => r.TryGetValidatableTypeInfo(typeof(ValidatableType), out It.Ref.IsAny), Times.Once); + resolver3.Verify(r => r.TryGetValidatableTypeInfo(typeof(ValidatableType), out It.Ref.IsAny), Times.Never); + } + + // Test types + private class NonValidatableType { } + + [ValidatableType] + private class ValidatableType + { + [Required] + public string Name { get; set; } = ""; + } + + private static class TestMethods + { + public static void MethodWithNonValidatableParam(NonValidatableType param) { } + public static void MethodWithValidatableParam(ValidatableType model) { } + } + + // Test implementations + private class TestValidatablePropertyInfo : ValidatablePropertyInfo + { + private readonly ValidationAttribute[] _validationAttributes; + + public TestValidatablePropertyInfo( + Type containingType, + Type propertyType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + : base(containingType, propertyType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + private class TestValidatableParameterInfo : ValidatableParameterInfo + { + private readonly ValidationAttribute[] _validationAttributes; + + public TestValidatableParameterInfo( + Type parameterType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + : base(parameterType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + private class TestValidatableTypeInfo( + Type type, + ValidatablePropertyInfo[] members) : ValidatableTypeInfo(type, members) + { + } +} diff --git a/src/Http/Http.Abstractions/test/Validation/ValidatableParameterInfoTests.cs b/src/Http/Http.Abstractions/test/Validation/ValidatableParameterInfoTests.cs new file mode 100644 index 000000000000..587a0e64d777 --- /dev/null +++ b/src/Http/Http.Abstractions/test/Validation/ValidatableParameterInfoTests.cs @@ -0,0 +1,405 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Http.Validation.Tests; + +public class ValidatableParameterInfoTests +{ + [Fact] + public async Task Validate_RequiredParameter_AddsErrorWhenNull() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(string), + name: "testParam", + displayName: "Test Parameter", + validationAttributes: [new RequiredAttribute()]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync(null, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("testParam", error.Key); + Assert.Equal("The Test Parameter field is required.", error.Value.First()); + } + + [Fact] + public async Task Validate_SkipsValidation_WhenNullAndNotRequired() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(string), + name: "testParam", + displayName: "Test Parameter", + validationAttributes: [new StringLengthAttribute(10)]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync(null, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.Null(errors); // No errors added + } + + [Fact] + public async Task Validate_WithRangeAttribute_ValidatesCorrectly() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(int), + name: "testParam", + displayName: "Test Parameter", + validationAttributes: [new RangeAttribute(10, 100)]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync(5, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("testParam", error.Key); + Assert.Equal("The field Test Parameter must be between 10 and 100.", error.Value.First()); + } + + [Fact] + public async Task Validate_WithDisplayNameAttribute_UsesDisplayNameInErrorMessage() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(string), + name: "testParam", + displayName: "Custom Display Name", + validationAttributes: [new RequiredAttribute()]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync(null, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("testParam", error.Key); + // The error message should use the display name + Assert.Equal("The Custom Display Name field is required.", error.Value.First()); + } + + [Fact] + public async Task Validate_WhenValidatableTypeHasErrors_AddsNestedErrors() + { + // Arrange + var personTypeInfo = new TestValidatableTypeInfo( + typeof(Person), + [ + new TestValidatablePropertyInfo( + typeof(Person), + typeof(string), + "Name", + "Name", + [new RequiredAttribute()]) + ]); + + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(Person), + name: "person", + displayName: "Person", + validationAttributes: []); + + var typeMapping = new Dictionary + { + { typeof(Person), personTypeInfo } + }; + + var context = CreateValidatableContext(typeMapping); + var person = new Person(); // Name is null, so should fail validation + + // Act + await paramInfo.ValidateAsync(person, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("Name", error.Key); + Assert.Equal("The Name field is required.", error.Value[0]); + } + + [Fact] + public async Task Validate_WithEnumerableOfValidatableType_ValidatesEachItem() + { + // Arrange + var personTypeInfo = new TestValidatableTypeInfo( + typeof(Person), + [ + new TestValidatablePropertyInfo( + typeof(Person), + typeof(string), + "Name", + "Name", + [new RequiredAttribute()]) + ]); + + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(IEnumerable), + name: "people", + displayName: "People", + validationAttributes: []); + + var typeMapping = new Dictionary + { + { typeof(Person), personTypeInfo } + }; + + var context = CreateValidatableContext(typeMapping); + var people = new List + { + new() { Name = "Valid" }, + new() // Name is null, should fail + }; + + // Act + await paramInfo.ValidateAsync(people, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("Name", error.Key); + Assert.Equal("The Name field is required.", error.Value[0]); + } + + [Fact] + public async Task Validate_MultipleErrorsOnSameParameter_CollectsAllErrors() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(int), + name: "testParam", + displayName: "Test Parameter", + validationAttributes: + [ + new RangeAttribute(10, 100) { ErrorMessage = "Range error" }, + new CustomTestValidationAttribute { ErrorMessage = "Custom error" } + ]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync(5, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("testParam", error.Key); + Assert.Collection(error.Value, + e => Assert.Equal("Range error", e), + e => Assert.Equal("Custom error", e)); + } + + [Fact] + public async Task Validate_WithContextPrefix_AddsErrorsWithCorrectPrefix() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(int), + name: "testParam", + displayName: "Test Parameter", + validationAttributes: [new RangeAttribute(10, 100)]); + + var context = CreateValidatableContext(); + context.Prefix = "parent"; + + // Act + await paramInfo.ValidateAsync(5, context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("parent.testParam", error.Key); + Assert.Equal("The field Test Parameter must be between 10 and 100.", error.Value.First()); + } + + [Fact] + public async Task Validate_ExceptionDuringValidation_CapturesExceptionAsError() + { + // Arrange + var paramInfo = CreateTestParameterInfo( + parameterType: typeof(string), + name: "testParam", + displayName: "Test Parameter", + validationAttributes: [new ThrowingValidationAttribute()]); + + var context = CreateValidatableContext(); + + // Act + await paramInfo.ValidateAsync("test", context, default); + + // Assert + var errors = context.ValidationErrors; + Assert.NotNull(errors); + var error = Assert.Single(errors); + Assert.Equal("testParam", error.Key); + Assert.Equal("Test exception", error.Value.First()); + } + + private TestValidatableParameterInfo CreateTestParameterInfo( + Type parameterType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + { + return new TestValidatableParameterInfo( + parameterType, + name, + displayName, + validationAttributes); + } + + private ValidatableContext CreateValidatableContext( + Dictionary? typeMapping = null) + { + var serviceProvider = new ServiceCollection().BuildServiceProvider(); + var validationContext = new ValidationContext(new object(), serviceProvider, null); + + return new ValidatableContext + { + ValidationContext = validationContext, + ValidationOptions = new TestValidationOptions(typeMapping ?? new Dictionary()) + }; + } + + private class TestValidatableParameterInfo : ValidatableParameterInfo + { + private readonly ValidationAttribute[] _validationAttributes; + + public TestValidatableParameterInfo( + Type parameterType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + : base(parameterType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + private class TestValidatablePropertyInfo : ValidatablePropertyInfo + { + private readonly ValidationAttribute[] _validationAttributes; + + public TestValidatablePropertyInfo( + Type containingType, + Type propertyType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + : base(containingType, propertyType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + private class TestValidatableTypeInfo( + Type type, + ValidatablePropertyInfo[] members) : ValidatableTypeInfo(type, members) + { + } + + private class TestValidationOptions : ValidationOptions + { + public TestValidationOptions(Dictionary typeInfoMappings) + { + // Create a custom resolver that uses the dictionary + var resolver = new DictionaryBasedResolver(typeInfoMappings); + + // Add it to the resolvers collection + Resolvers.Add(resolver); + } + + // Private resolver implementation that uses a dictionary lookup + private class DictionaryBasedResolver : IValidatableInfoResolver + { + private readonly Dictionary _typeInfoMappings; + + public DictionaryBasedResolver(Dictionary typeInfoMappings) + { + _typeInfoMappings = typeInfoMappings; + } + + public ValidatableTypeInfo? TryGetValidatableTypeInfo(Type type) + { + _typeInfoMappings.TryGetValue(type, out var info); + return info; + } + + public ValidatableParameterInfo? GetValidatableParameterInfo(ParameterInfo parameterInfo) + { + // Not implemented in the test + return null; + } + + public bool TryGetValidatableTypeInfo(Type type, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + if (_typeInfoMappings.TryGetValue(type, out var validatableTypeInfo)) + { + validatableInfo = validatableTypeInfo; + return true; + } + validatableInfo = null; + return false; + } + + public bool TryGetValidatableParameterInfo(ParameterInfo parameterInfo, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + } + } + + // Test data classes and validation attributes + + private class Person + { + public string? Name { get; set; } + } + + private class CustomTestValidationAttribute : ValidationAttribute + { + public override bool IsValid(object? value) + { + // Always fail for testing + return false; + } + } + + private class ThrowingValidationAttribute : ValidationAttribute + { + public override bool IsValid(object? value) + { + throw new InvalidOperationException("Test exception"); + } + } +} diff --git a/src/Http/Http.Abstractions/test/Validation/ValidatableTypeInfoTests.cs b/src/Http/Http.Abstractions/test/Validation/ValidatableTypeInfoTests.cs new file mode 100644 index 000000000000..764e6cd26f50 --- /dev/null +++ b/src/Http/Http.Abstractions/test/Validation/ValidatableTypeInfoTests.cs @@ -0,0 +1,676 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace Microsoft.AspNetCore.Http.Validation.Tests; + +public class ValidatableTypeInfoTests +{ + [Fact] + public async Task Validate_ValidatesComplexType_WithNestedProperties() + { + // Arrange + var personType = new TestValidatableTypeInfo( + typeof(Person), + [ + CreatePropertyInfo(typeof(Person), typeof(string), "Name", "Name", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(Person), typeof(int), "Age", "Age", + [new RangeAttribute(0, 120)]), + CreatePropertyInfo(typeof(Person), typeof(Address), "Address", "Address", + []) + ]); + + var addressType = new TestValidatableTypeInfo( + typeof(Address), + [ + CreatePropertyInfo(typeof(Address), typeof(string), "Street", "Street", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(Address), typeof(string), "City", "City", + [new RequiredAttribute()]) + ]); + + var validationOptions = new TestValidationOptions(new Dictionary + { + { typeof(Person), personType }, + { typeof(Address), addressType } + }); + + var context = new ValidatableContext + { + ValidationOptions = validationOptions, + }; + + var personWithMissingRequiredFields = new Person + { + Age = 150, // Invalid age + Address = new Address() // Missing required City and Street + }; + context.ValidationContext = new ValidationContext(personWithMissingRequiredFields); + + // Act + await personType.ValidateAsync(personWithMissingRequiredFields, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + Assert.Collection(context.ValidationErrors, + kvp => + { + Assert.Equal("Name", kvp.Key); + Assert.Equal("The Name field is required.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Age", kvp.Key); + Assert.Equal("The field Age must be between 0 and 120.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Address.Street", kvp.Key); + Assert.Equal("The Street field is required.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Address.City", kvp.Key); + Assert.Equal("The City field is required.", kvp.Value.First()); + }); + } + + [Fact] + public async Task Validate_HandlesIValidatableObject_Implementation() + { + // Arrange + var employeeType = new TestValidatableTypeInfo( + typeof(Employee), + [ + CreatePropertyInfo(typeof(Employee), typeof(string), "Name", "Name", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(Employee), typeof(string), "Department", "Department", + []), + CreatePropertyInfo(typeof(Employee), typeof(decimal), "Salary", "Salary", + []) + ]); + + var context = new ValidatableContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(Employee), employeeType } + }) + }; + + var employee = new Employee + { + Name = "John Doe", + Department = "IT", + Salary = -5000 // Negative salary will trigger IValidatableObject validation + }; + context.ValidationContext = new ValidationContext(employee); + + // Act + await employeeType.ValidateAsync(employee, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + var error = Assert.Single(context.ValidationErrors); + Assert.Equal("Salary", error.Key); + Assert.Equal("Salary must be a positive value.", error.Value.First()); + } + + [Fact] + public async Task Validate_HandlesPolymorphicTypes_WithSubtypes() + { + // Arrange + var baseType = new TestValidatableTypeInfo( + typeof(Vehicle), + [ + CreatePropertyInfo(typeof(Vehicle), typeof(string), "Make", "Make", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(Vehicle), typeof(string), "Model", "Model", + [new RequiredAttribute()]) + ]); + + var derivedType = new TestValidatableTypeInfo( + typeof(Car), + [ + CreatePropertyInfo(typeof(Car), typeof(int), "Doors", "Doors", + [new RangeAttribute(2, 5)]) + ]); + + var context = new ValidatableContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(Vehicle), baseType }, + { typeof(Car), derivedType } + }) + }; + + var car = new Car + { + // Missing Make and Model (required in base type) + Doors = 7 // Invalid number of doors + }; + context.ValidationContext = new ValidationContext(car); + + // Act + await derivedType.ValidateAsync(car, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + Assert.Collection(context.ValidationErrors, + kvp => + { + Assert.Equal("Doors", kvp.Key); + Assert.Equal("The field Doors must be between 2 and 5.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Make", kvp.Key); + Assert.Equal("The Make field is required.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Model", kvp.Key); + Assert.Equal("The Model field is required.", kvp.Value.First()); + }); + } + + [Fact] + public async Task Validate_HandlesCollections_OfValidatableTypes() + { + // Arrange + var itemType = new TestValidatableTypeInfo( + typeof(OrderItem), + [ + CreatePropertyInfo(typeof(OrderItem), typeof(string), "ProductName", "ProductName", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(OrderItem), typeof(int), "Quantity", "Quantity", + [new RangeAttribute(1, 100)]) + ]); + + var orderType = new TestValidatableTypeInfo( + typeof(Order), + [ + CreatePropertyInfo(typeof(Order), typeof(string), "OrderNumber", "OrderNumber", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(Order), typeof(List), "Items", "Items", + []) + ]); + + var context = new ValidatableContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(OrderItem), itemType }, + { typeof(Order), orderType } + }) + }; + + var order = new Order + { + OrderNumber = "ORD-12345", + Items = + [ + new OrderItem { ProductName = "Valid Product", Quantity = 5 }, + new OrderItem { /* Missing ProductName (required) */ Quantity = 0 /* Invalid quantity */ }, + new OrderItem { ProductName = "Another Product", Quantity = 200 /* Invalid quantity */ } + ] + }; + context.ValidationContext = new ValidationContext(order); + + // Act + await orderType.ValidateAsync(order, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + Assert.Collection(context.ValidationErrors, + kvp => + { + Assert.Equal("Items[1].ProductName", kvp.Key); + Assert.Equal("The ProductName field is required.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Items[1].Quantity", kvp.Key); + Assert.Equal("The field Quantity must be between 1 and 100.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("Items[2].Quantity", kvp.Key); + Assert.Equal("The field Quantity must be between 1 and 100.", kvp.Value.First()); + }); + } + + [Fact] + public async Task Validate_HandlesNullValues_Appropriately() + { + // Arrange + var personType = new TestValidatableTypeInfo( + typeof(Person), + [ + CreatePropertyInfo(typeof(Person), typeof(string), "Name", "Name", + []), + CreatePropertyInfo(typeof(Person), typeof(Address), "Address", "Address", + []) + ]); + + var context = new ValidatableContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(Person), personType } + }) + }; + + var person = new Person + { + Name = null, + Address = null + }; + context.ValidationContext = new ValidationContext(person); + + // Act + await personType.ValidateAsync(person, context, default); + + // Assert + Assert.Null(context.ValidationErrors); // No validation errors for nullable properties with null values + } + + [Fact] + public async Task Validate_RespectsMaxDepthOption_ForCircularReferences() + { + // Arrange + // Create a type that can contain itself (circular reference) + var nodeType = new TestValidatableTypeInfo( + typeof(TreeNode), + [ + CreatePropertyInfo(typeof(TreeNode), typeof(string), "Name", "Name", + [new RequiredAttribute()]), + CreatePropertyInfo(typeof(TreeNode), typeof(TreeNode), "Parent", "Parent", + []), + CreatePropertyInfo(typeof(TreeNode), typeof(List), "Children", "Children", + []) + ]); + + // Create a validation options with a small max depth + var validationOptions = new TestValidationOptions(new Dictionary + { + { typeof(TreeNode), nodeType } + }); + validationOptions.MaxDepth = 3; // Set a small max depth to trigger the limit + + var context = new ValidatableContext + { + ValidationOptions = validationOptions, + ValidationErrors = [] + }; + + // Create a deep tree with circular references + var rootNode = new TreeNode { Name = "Root" }; + var level1 = new TreeNode { Name = "Level1", Parent = rootNode }; + var level2 = new TreeNode { Name = "Level2", Parent = level1 }; + var level3 = new TreeNode { Name = "Level3", Parent = level2 }; + var level4 = new TreeNode { Name = "" }; // Invalid: missing required name + var level5 = new TreeNode { Name = "" }; // Invalid but beyond max depth, should not be validated + + rootNode.Children.Add(level1); + level1.Children.Add(level2); + level2.Children.Add(level3); + level3.Children.Add(level4); + level4.Children.Add(level5); + + // Add a circular reference + level5.Children.Add(rootNode); + + context.ValidationContext = new ValidationContext(rootNode); + + // Act + Assert + var exception = await Assert.ThrowsAsync( + async () => await nodeType.ValidateAsync(rootNode, context, default)); + + Assert.NotNull(exception); + Assert.Contains("Maximum validation depth of 3 exceeded at 'Children[0].Parent.Children[0]'. This is likely caused by a circular reference in the object graph. Consider increasing the MaxDepth in ValidationOptions if deeper validation is required.", exception.Message); + } + + [Fact] + public async Task Validate_HandlesCustomValidationAttributes() + { + // Arrange + var productType = new TestValidatableTypeInfo( + typeof(Product), + [ + CreatePropertyInfo(typeof(Product), typeof(string), "SKU", "SKU", [new RequiredAttribute(), new CustomSkuValidationAttribute()]), + ]); + + var context = new ValidatableContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(Product), productType } + }) + }; + + var product = new Product { SKU = "INVALID-SKU" }; + context.ValidationContext = new ValidationContext(product); + + // Act + await productType.ValidateAsync(product, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + var error = Assert.Single(context.ValidationErrors); + Assert.Equal("SKU", error.Key); + Assert.Equal("SKU must start with 'PROD-'.", error.Value.First()); + } + + [Fact] + public async Task Validate_HandlesMultipleErrorsOnSameProperty() + { + // Arrange + var userType = new TestValidatableTypeInfo( + typeof(User), + [ + CreatePropertyInfo(typeof(User), typeof(string), "Password", "Password", + [ + new RequiredAttribute(), + new MinLengthAttribute(8) { ErrorMessage = "Password must be at least 8 characters." }, + new PasswordComplexityAttribute() + ]) + ]); + + var context = new ValidatableContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(User), userType } + }) + }; + + var user = new User { Password = "abc" }; // Too short and not complex enough + context.ValidationContext = new ValidationContext(user); + + // Act + await userType.ValidateAsync(user, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + Assert.Single(context.ValidationErrors.Keys); // Only the "Password" key + Assert.Equal(2, context.ValidationErrors["Password"].Length); // But with 2 errors + Assert.Contains("Password must be at least 8 characters.", context.ValidationErrors["Password"]); + Assert.Contains("Password must contain at least one number and one special character.", context.ValidationErrors["Password"]); + } + + [Fact] + public async Task Validate_HandlesMultiLevelInheritance() + { + // Arrange + var baseType = new TestValidatableTypeInfo( + typeof(BaseEntity), + [ + CreatePropertyInfo(typeof(BaseEntity), typeof(Guid), "Id", "Id", []) + ]); + + var intermediateType = new TestValidatableTypeInfo( + typeof(IntermediateEntity), + [ + CreatePropertyInfo(typeof(IntermediateEntity), typeof(DateTime), "CreatedAt", "CreatedAt", [new PastDateAttribute()]) + ]); + + var derivedType = new TestValidatableTypeInfo( + typeof(DerivedEntity), + [ + CreatePropertyInfo(typeof(DerivedEntity), typeof(string), "Name", "Name", [new RequiredAttribute()]) + ]); + + var context = new ValidatableContext + { + ValidationOptions = new TestValidationOptions(new Dictionary + { + { typeof(BaseEntity), baseType }, + { typeof(IntermediateEntity), intermediateType }, + { typeof(DerivedEntity), derivedType } + }) + }; + + var entity = new DerivedEntity + { + Name = "", // Invalid: required + CreatedAt = DateTime.Now.AddDays(1) // Invalid: future date + }; + context.ValidationContext = new ValidationContext(entity); + + // Act + await derivedType.ValidateAsync(entity, context, default); + + // Assert + Assert.NotNull(context.ValidationErrors); + Assert.Collection(context.ValidationErrors, + kvp => + { + Assert.Equal("Name", kvp.Key); + Assert.Equal("The Name field is required.", kvp.Value.First()); + }, + kvp => + { + Assert.Equal("CreatedAt", kvp.Key); + Assert.Equal("Date must be in the past.", kvp.Value.First()); + }); + } + + private ValidatablePropertyInfo CreatePropertyInfo( + Type containingType, + Type propertyType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + { + return new TestValidatablePropertyInfo( + containingType, + propertyType, + name, + displayName, + validationAttributes); + } + + // Test model classes + private class Person + { + public string? Name { get; set; } + public int Age { get; set; } + public Address? Address { get; set; } + } + + private class Address + { + public string? Street { get; set; } + public string? City { get; set; } + } + + private class Employee : IValidatableObject + { + public string? Name { get; set; } + public string? Department { get; set; } + public decimal Salary { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Salary < 0) + { + yield return new ValidationResult("Salary must be a positive value.", new[] { nameof(Salary) }); + } + } + } + + private class Vehicle + { + public string? Make { get; set; } + public string? Model { get; set; } + } + + private class Car : Vehicle + { + public int Doors { get; set; } + } + + private class Order + { + public string? OrderNumber { get; set; } + public List Items { get; set; } = []; + } + + private class OrderItem + { + public string? ProductName { get; set; } + public int Quantity { get; set; } + } + + private class TreeNode + { + public string Name { get; set; } = string.Empty; + public TreeNode? Parent { get; set; } + public List Children { get; set; } = []; + } + + private class Product + { + public string SKU { get; set; } = string.Empty; + } + + private class User + { + public string Password { get; set; } = string.Empty; + } + + private class BaseEntity + { + public Guid Id { get; set; } = Guid.NewGuid(); + } + + private class IntermediateEntity : BaseEntity + { + public DateTime CreatedAt { get; set; } + } + + private class DerivedEntity : IntermediateEntity + { + public string Name { get; set; } = string.Empty; + } + + private class PastDateAttribute : ValidationAttribute + { + protected override ValidationResult? IsValid(object? value, ValidationContext validationContext) + { + if (value is DateTime date && date > DateTime.Now) + { + return new ValidationResult("Date must be in the past."); + } + + return ValidationResult.Success; + } + } + + private class CustomSkuValidationAttribute : ValidationAttribute + { + protected override ValidationResult? IsValid(object? value, ValidationContext validationContext) + { + if (value is string sku && !sku.StartsWith("PROD-", StringComparison.Ordinal)) + { + return new ValidationResult("SKU must start with 'PROD-'."); + } + + return ValidationResult.Success; + } + } + + private class PasswordComplexityAttribute : ValidationAttribute + { + protected override ValidationResult? IsValid(object? value, ValidationContext validationContext) + { + if (value is string password) + { + var hasDigit = password.Any(c => char.IsDigit(c)); + var hasSpecial = password.Any(c => !char.IsLetterOrDigit(c)); + + if (!hasDigit || !hasSpecial) + { + return new ValidationResult("Password must contain at least one number and one special character."); + } + } + + return ValidationResult.Success; + } + } + + // Test implementations + private class TestValidatablePropertyInfo : ValidatablePropertyInfo + { + private readonly ValidationAttribute[] _validationAttributes; + + public TestValidatablePropertyInfo( + Type containingType, + Type propertyType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) + : base(containingType, propertyType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + private class TestValidatableTypeInfo : ValidatableTypeInfo + { + public TestValidatableTypeInfo( + Type type, + ValidatablePropertyInfo[] members) + : base(type, members) + { + } + } + + private class TestValidationOptions : ValidationOptions + { + public TestValidationOptions(Dictionary typeInfoMappings) + { + // Create a custom resolver that uses the dictionary + var resolver = new DictionaryBasedResolver(typeInfoMappings); + + // Add it to the resolvers collection + Resolvers.Add(resolver); + } + + // Private resolver implementation that uses a dictionary lookup + private class DictionaryBasedResolver : IValidatableInfoResolver + { + private readonly Dictionary _typeInfoMappings; + + public DictionaryBasedResolver(Dictionary typeInfoMappings) + { + _typeInfoMappings = typeInfoMappings; + } + + public bool TryGetValidatableTypeInfo(Type type, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + if (_typeInfoMappings.TryGetValue(type, out var info)) + { + validatableInfo = info; + return true; + } + validatableInfo = null; + return false; + } + + public bool TryGetValidatableParameterInfo(ParameterInfo parameterInfo, [NotNullWhen(true)] out IValidatableInfo? validatableInfo) + { + validatableInfo = null; + return false; + } + } + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Emitters/ValidationsGenerator.Emitter.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Emitters/ValidationsGenerator.Emitter.cs new file mode 100644 index 000000000000..b99518e9130d --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Emitters/ValidationsGenerator.Emitter.cs @@ -0,0 +1,500 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; +using System.Text; +using System.Linq; +using Microsoft.CodeAnalysis.CSharp; +using System.IO; +using System.Collections.Generic; +using System; +using System.Globalization; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + public static string GeneratedCodeConstructor => $@"global::System.CodeDom.Compiler.GeneratedCodeAttribute(""{typeof(ValidationsGenerator).Assembly.FullName}"", ""{typeof(ValidationsGenerator).Assembly.GetName().Version}"")"; + public static string GeneratedCodeAttribute => $"[{GeneratedCodeConstructor}]"; + + internal static void Emit(SourceProductionContext context, (InterceptableLocation? AddValidation, ImmutableArray ValidatableTypes) emitInputs) + { + var source = Emit(emitInputs.AddValidation, emitInputs.ValidatableTypes); + context.AddSource("ValidatableInfoResolver.g.cs", SourceText.From(source, Encoding.UTF8)); + } + + private static string Emit(InterceptableLocation? addValidation, ImmutableArray validatableTypes) => $$""" +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + {{GeneratedCodeAttribute}} + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + {{GeneratedCodeAttribute}} + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + private readonly global::System.ComponentModel.DataAnnotations.ValidationAttribute[] _validationAttributes; + + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName, + global::System.ComponentModel.DataAnnotations.ValidationAttribute[] validationAttributes) : base(containingType, propertyType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + {{GeneratedCodeAttribute}} + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + {{GeneratedCodeAttribute}} + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo validatableInfo) + { + validatableInfo = null; +{{EmitTypeChecks(validatableTypes)}} + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo validatableInfo) + { + validatableInfo = null; + return false; + } + +{{EmitCreateMethods(validatableTypes)}} + } + + {{GeneratedCodeAttribute}} + file static class GeneratedServiceCollectionExtensions + { + {{addValidation!.GetInterceptsLocationAttributeSyntax()}} + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + {{GeneratedCodeAttribute}} + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type AttributeType, object[] Arguments, global::System.Collections.Generic.Dictionary NamedArguments); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute? GetOrCreateValidationAttribute( + global::System.Type attributeType, + object[] arguments, + global::System.Collections.Generic.Dictionary namedArguments) + { + var key = new CacheKey(attributeType, arguments, namedArguments); + return _cache.GetOrAdd(key, static k => + { + var type = k.AttributeType; + var args = k.Arguments; + + global::System.ComponentModel.DataAnnotations.ValidationAttribute attribute; + + if (args.Length == 0) + { + attribute = type switch + { + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute) => new global::System.ComponentModel.DataAnnotations.RequiredAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.EmailAddressAttribute) => new global::System.ComponentModel.DataAnnotations.EmailAddressAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.PhoneAttribute) => new global::System.ComponentModel.DataAnnotations.PhoneAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.UrlAttribute) => new global::System.ComponentModel.DataAnnotations.UrlAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.CreditCardAttribute) => new global::System.ComponentModel.DataAnnotations.CreditCardAttribute(), + _ when typeof(global::System.ComponentModel.DataAnnotations.ValidationAttribute).IsAssignableFrom(type) => + (global::System.ComponentModel.DataAnnotations.ValidationAttribute)global::System.Activator.CreateInstance(type)! + }; + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.CustomValidationAttribute) && args.Length == 2) + { + // CustomValidationAttribute requires special handling + // First argument is a type, second is a method name + if (args[0] is global::System.Type validatingType && args[1] is string methodName) + { + attribute = new global::System.ComponentModel.DataAnnotations.CustomValidationAttribute(validatingType, methodName); + } + else + { + throw new global::System.ArgumentException($"Invalid arguments for CustomValidationAttribute: Type and method name required"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.StringLengthAttribute)) + { + if (args[0] is int maxLength) + { + attribute = new global::System.ComponentModel.DataAnnotations.StringLengthAttribute(maxLength); + } + else + { + throw new global::System.ArgumentException($"Invalid maxLength value for StringLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.MinLengthAttribute)) + { + if (args[0] is int length) + { + attribute = new global::System.ComponentModel.DataAnnotations.MinLengthAttribute(length); + } + else + { + throw new global::System.ArgumentException($"Invalid length value for MinLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.MaxLengthAttribute)) + { + if (args[0] is int length) + { + attribute = new global::System.ComponentModel.DataAnnotations.MaxLengthAttribute(length); + } + else + { + throw new global::System.ArgumentException($"Invalid length value for MaxLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute) && args.Length == 2) + { + if (args[0] is int min && args[1] is int max) + { + attribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(min, max); + } + else if (args[0] is double dmin && args[1] is double dmax) + { + attribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(dmin, dmax); + } + else + { + throw new global::System.ArgumentException($"Invalid range values for RangeAttribute: {args[0]}, {args[1]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute)) + { + if (args[0] is string pattern) + { + attribute = new global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute(pattern); + } + else + { + throw new global::System.ArgumentException($"Invalid pattern for RegularExpressionAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.CompareAttribute)) + { + if (args[0] is string otherProperty) + { + attribute = new global::System.ComponentModel.DataAnnotations.CompareAttribute(otherProperty); + } + else + { + throw new global::System.ArgumentException($"Invalid otherProperty for CompareAttribute: {args[0]}"); + } + } + else if (typeof(global::System.ComponentModel.DataAnnotations.ValidationAttribute).IsAssignableFrom(type)) + { + var constructors = type.GetConstructors(); + var success = false; + attribute = null!; + + foreach (var constructor in constructors) + { + var parameters = constructor.GetParameters(); + if (parameters.Length != args.Length) + continue; + + var convertedArgs = new object[args.Length]; + var canUseConstructor = true; + + for (var i = 0; i < parameters.Length; i++) + { + try + { + if (args[i] != null && args[i].GetType() == parameters[i].ParameterType) + { + // Type already matches, use as-is + convertedArgs[i] = args[i]; + } + else + { + // Try to convert + convertedArgs[i] = global::System.Convert.ChangeType(args[i], parameters[i].ParameterType); + } + } + catch + { + canUseConstructor = false; + break; + } + } + + if (canUseConstructor) + { + attribute = (global::System.ComponentModel.DataAnnotations.ValidationAttribute)global::System.Activator.CreateInstance(type, convertedArgs)!; + success = true; + break; + } + } + + if (!success) + { + throw new global::System.ArgumentException($"Could not find a suitable constructor for validation attribute type: {type.FullName}"); + } + } + else + { + throw new global::System.ArgumentException($"Unsupported validation attribute type: {type.FullName}"); + } + + // Apply named arguments after construction + foreach (var namedArg in k.NamedArguments) + { + var prop = type.GetProperty(namedArg.Key); + if (prop != null && prop.CanWrite) + { + try + { + if (namedArg.Value != null && namedArg.Value.GetType() == prop.PropertyType) + { + // Type already matches, use as-is + prop.SetValue(attribute, namedArg.Value); + } + else + { + // Try to convert + prop.SetValue(attribute, global::System.Convert.ChangeType(namedArg.Value, prop.PropertyType)); + } + } + catch (global::System.Exception ex) + { + throw new global::System.ArgumentException($"Failed to set property {namedArg.Key} on {type.FullName}: {ex.Message}"); + } + } + } + + return attribute; + }); + } + } +} +"""; + + private static string EmitValidationAttributeForCreate(ValidationAttribute attr) + { + // Process constructor arguments - convert to appropriate typed objects + var processedArgs = new List(attr.Arguments.Count); + + foreach (var arg in attr.Arguments) + { + // Handle different types of arguments + if (arg.StartsWith("\"", StringComparison.OrdinalIgnoreCase) && arg.EndsWith("\"", StringComparison.OrdinalIgnoreCase)) + { + // String literal - remove quotes and pass as object + var stringValue = arg.Substring(1, arg.Length - 2).Replace("\\\"", "\""); + processedArgs.Add($"\"{stringValue}\""); + } + else if (arg.StartsWith("typeof(", StringComparison.OrdinalIgnoreCase) && arg.EndsWith(")", StringComparison.OrdinalIgnoreCase)) + { + // Type argument - pass directly + processedArgs.Add(arg); + } + else if (int.TryParse(arg, out var intValue)) + { + // Integer + processedArgs.Add(intValue.ToString(CultureInfo.InvariantCulture)); + } + else if (double.TryParse(arg, out var doubleValue)) + { + // Double + processedArgs.Add(doubleValue.ToString(CultureInfo.InvariantCulture) + "d"); + } + else if (bool.TryParse(arg, out var boolValue)) + { + // Boolean + processedArgs.Add(boolValue.ToString().ToLowerInvariant()); + } + else if (arg == "null") + { + // Null + processedArgs.Add("null"); + } + else + { + // Default to string for anything else + processedArgs.Add($"\"{arg.Replace("\"", "\\\"")}\""); + } + } + + var args = attr.Arguments.Count > 0 + ? $"[{string.Join(", ", processedArgs)}]" + : "[]"; + + // Process named arguments - ensure proper formatting for object dictionary + var namedArgsParts = new List(attr.NamedArguments.Count); + foreach (var pair in attr.NamedArguments) + { + // Convert the value based on its format + var valueStr = pair.Value; + string objectValue; + + if (valueStr.StartsWith("\"", StringComparison.OrdinalIgnoreCase) && valueStr.EndsWith("\"", StringComparison.OrdinalIgnoreCase)) + { + // String literal + objectValue = valueStr; + } + else if (valueStr.StartsWith("typeof(", StringComparison.OrdinalIgnoreCase) && valueStr.EndsWith(")", StringComparison.OrdinalIgnoreCase)) + { + // Type argument + objectValue = valueStr; + } + else if (int.TryParse(valueStr, out _)) + { + // Integer + objectValue = valueStr; + } + else if (double.TryParse(valueStr, out _)) + { + // Double + objectValue = valueStr + "d"; + } + else if (bool.TryParse(valueStr, out var boolVal)) + { + // Boolean + objectValue = boolVal.ToString().ToLowerInvariant(); + } + else if (valueStr == "null") + { + // Null + objectValue = "null"; + } + else + { + // Default to string for anything else + objectValue = $"\"{valueStr.Replace("\"", "\\\"")}\""; + } + + namedArgsParts.Add($"{{ \"{pair.Key}\", {objectValue} }}"); + } + + var namedArgs = attr.NamedArguments.Count > 0 + ? $"new global::System.Collections.Generic.Dictionary {{ {string.Join(", ", namedArgsParts)} }}" + : "[]"; + + // Use string interpolation with @ to prevent escaping issues in the error message + return $@"ValidationAttributeCache.GetOrCreateValidationAttribute(typeof({attr.ClassName}), {args}, {namedArgs}) ?? throw new global::System.InvalidOperationException(@""Failed to create validation attribute {attr.ClassName}"")"; + } + + private static string EmitTypeChecks(ImmutableArray validatableTypes) + { + var sw = new StringWriter(); + var cw = new CodeWriter(sw, baseIndent: 2); + foreach (var validatableType in validatableTypes) + { + var typeName = validatableType.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + cw.WriteLine($"if (type == typeof({typeName}))"); + cw.StartBlock(); + cw.WriteLine($"validatableInfo = Create{SanitizeTypeName(validatableType.Type.MetadataName)}();"); + cw.WriteLine("return true;"); + cw.EndBlock(); + } + return sw.ToString(); + } + + private static string EmitCreateMethods(ImmutableArray validatableTypes) + { + var sw = new StringWriter(); + var cw = new CodeWriter(sw, baseIndent: 2); + foreach (var validatableType in validatableTypes) + { + cw.WriteLine($@"private ValidatableTypeInfo Create{SanitizeTypeName(validatableType.Type.MetadataName)}()"); + cw.StartBlock(); + cw.WriteLine("return new GeneratedValidatableTypeInfo("); + cw.Indent++; + cw.WriteLine($"type: typeof({validatableType.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}),"); + if (validatableType.Members.IsDefaultOrEmpty) + { + cw.WriteLine("members: []"); + } + else + { + cw.WriteLine("members: ["); + cw.Indent++; + foreach (var member in validatableType.Members) + { + EmitValidatableMemberForCreate(member, cw); + } + cw.Indent--; + cw.WriteLine("]"); + } + cw.Indent--; + cw.WriteLine(");"); + cw.EndBlock(); + } + return sw.ToString(); + } + + private static void EmitValidatableMemberForCreate(ValidatableProperty member, CodeWriter cw) + { + var validationAttributes = member.Attributes.IsDefaultOrEmpty + ? "[]" + : $"[{string.Join(", ", member.Attributes.Select(EmitValidationAttributeForCreate))}]"; + cw.WriteLine("new GeneratedValidatablePropertyInfo("); + cw.Indent++; + cw.WriteLine($"containingType: typeof({member.ContainingType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}),"); + cw.WriteLine($"propertyType: typeof({member.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}),"); + cw.WriteLine($"name: \"{member.Name}\","); + cw.WriteLine($"displayName: \"{member.DisplayName}\","); + cw.WriteLine($"validationAttributes: {validationAttributes}"); + cw.Indent--; + cw.WriteLine("),"); + } + + private static string SanitizeTypeName(string typeName) + { + // Replace invalid characters with underscores and remove generic notation + return typeName + .Replace(".", "_") + .Replace("<", "_") + .Replace(">", "_") + .Replace(",", "_") + .Replace(" ", "_"); + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ISymbolExtensions.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ISymbolExtensions.cs new file mode 100644 index 000000000000..54efe204c1ec --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ISymbolExtensions.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal static class ISymbolExtensions +{ + public static string GetDisplayName(this ISymbol property, INamedTypeSymbol displayAttribute) + { + var displayNameAttribute = property.GetAttributes() + .FirstOrDefault(attribute => + attribute.AttributeClass is { } attributeClass && + SymbolEqualityComparer.Default.Equals(attributeClass, displayAttribute)); + if (displayNameAttribute is not null) + { + if (displayNameAttribute.ConstructorArguments.Length > 0) + { + return displayNameAttribute.ConstructorArguments[0].Value?.ToString() ?? property.Name; + } + else if (displayNameAttribute.NamedArguments.Length > 0) + { + return displayNameAttribute.NamedArguments[0].Value.Value?.ToString() ?? property.Name; + } + return property.Name; + } + + return property.Name; + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ITypeSymbolExtensions.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ITypeSymbolExtensions.cs new file mode 100644 index 000000000000..c13feb21c495 --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ITypeSymbolExtensions.cs @@ -0,0 +1,88 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal static class ITypeSymbolExtensions +{ + public static bool IsEnumerable(this ITypeSymbol type, INamedTypeSymbol enumerable) + { + if (type.SpecialType == SpecialType.System_String) + { + return false; + } + + return type.ImplementsInterface(enumerable) || SymbolEqualityComparer.Default.Equals(type, enumerable); + } + + public static bool ImplementsValidationAttribute(this ITypeSymbol typeSymbol, INamedTypeSymbol validationAttributeSymbol) + { + var baseType = typeSymbol.BaseType; + while (baseType != null) + { + if (SymbolEqualityComparer.Default.Equals(baseType, validationAttributeSymbol)) + { + return true; + } + baseType = baseType.BaseType; + } + + return false; + } + + public static ITypeSymbol UnwrapType(this ITypeSymbol type, INamedTypeSymbol enumerable) + { + if (type.OriginalDefinition.SpecialType == SpecialType.System_Nullable_T) + { + // Extract the T from a Nullable + type = ((INamedTypeSymbol)type).TypeArguments[0]; + } + + if (type.NullableAnnotation == NullableAnnotation.Annotated) + { + // Extract the underlying type from a reference type + type = type.OriginalDefinition; + } + + if (type is INamedTypeSymbol namedType && namedType.IsEnumerable(enumerable)) + { + // Extract the T from an IEnumerable or List + type = namedType.TypeArguments[0]; + } + + return type; + } + + internal static bool ImplementsInterface(this ITypeSymbol type, ITypeSymbol interfaceType) + { + foreach (var iface in type.AllInterfaces) + { + if (SymbolEqualityComparer.Default.Equals(interfaceType, iface)) + { + return true; + } + } + return false; + } + + internal static ImmutableArray? GetJsonDerivedTypes(this ITypeSymbol type, INamedTypeSymbol jsonDerivedTypeAttribute) + { + var derivedTypes = ImmutableArray.CreateBuilder(); + foreach (var attribute in type.GetAttributes()) + { + if (SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, jsonDerivedTypeAttribute)) + { + var derivedType = (INamedTypeSymbol?)attribute.ConstructorArguments[0].Value; + if (derivedType is not null && !SymbolEqualityComparer.Default.Equals(derivedType, type)) + { + derivedTypes.Add(derivedType); + } + } + } + + return derivedTypes.Count == 0 ? null : derivedTypes.ToImmutable(); + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/IncrementalValuesProviderExtensions.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/IncrementalValuesProviderExtensions.cs new file mode 100644 index 000000000000..a95c12d657ee --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/IncrementalValuesProviderExtensions.cs @@ -0,0 +1,109 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal static class IncrementalValuesProviderExtensions +{ + public static IncrementalValuesProvider Distinct(this IncrementalValuesProvider source, IEqualityComparer comparer) + { + return source + .Collect() + .WithComparer(ImmutableArrayEqualityComparer.Instance) + .SelectMany((values, cancellationToken) => + { + if (values.IsEmpty) + { + return values; + } + + var results = ImmutableArray.CreateBuilder(values.Length); + HashSet set = new(comparer); + + foreach (var value in values) + { + if (set.Add(value)) + { + results.Add(value); + } + } + + return results.DrainToImmutable(); + }); + } + + public static IncrementalValuesProvider Concat( + this IncrementalValuesProvider> first, + IncrementalValuesProvider> second) + { + return first + .Combine(second.Collect()) + .SelectMany((tuple, _) => + { + if (tuple.Left.IsEmpty && tuple.Right.IsEmpty) + { + return []; + } + + var results = ImmutableArray.CreateBuilder(tuple.Left.Length + tuple.Right.Length); + results.AddRange(tuple.Left); + for (var i = 0; i < tuple.Right.Length; i++) + { + results.AddRange(tuple.Right[i]); + } + return results.DrainToImmutable(); + }); + } + + private sealed class ImmutableArrayEqualityComparer : IEqualityComparer> + { + public static readonly ImmutableArrayEqualityComparer Instance = new(); + + public bool Equals(ImmutableArray x, ImmutableArray y) + { + if (x.IsDefault) + { + return y.IsDefault; + } + else if (y.IsDefault) + { + return false; + } + + if (x.Length != y.Length) + { + return false; + } + + for (var i = 0; i < x.Length; i++) + { + if (!EqualityComparer.Default.Equals(x[i], y[i])) + { + return false; + } + } + + return true; + } + + public int GetHashCode(ImmutableArray obj) + { + if (obj.IsDefault) + { + return 0; + } + var hashCode = -450793227; + foreach (var item in obj) + { + hashCode = (hashCode * -1521134295) + EqualityComparer.Default.GetHashCode(item); + } + + return hashCode; + } + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/RequiredSymbols.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/RequiredSymbols.cs new file mode 100644 index 000000000000..0d67e2cca980 --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/RequiredSymbols.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class RequiredSymbols( + INamedTypeSymbol DisplayAttribute, + INamedTypeSymbol ValidationAttribute, + INamedTypeSymbol IEnumerable, + INamedTypeSymbol IValidatableObject, + INamedTypeSymbol JsonDerivedTypeAttribute, + INamedTypeSymbol RequiredAttribute, + INamedTypeSymbol CustomValidationAttribute +); diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableProperty.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableProperty.cs new file mode 100644 index 000000000000..658f27a82e6b --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableProperty.cs @@ -0,0 +1,15 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class ValidatableProperty( + ITypeSymbol ContainingType, + ITypeSymbol Type, + string Name, + string DisplayName, + ImmutableArray Attributes +); diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableType.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableType.cs new file mode 100644 index 000000000000..c6d7e36f36a9 --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableType.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class ValidatableType( + ITypeSymbol Type, + ImmutableArray Members +); diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableTypeComparer.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableTypeComparer.cs new file mode 100644 index 000000000000..fcd99f51dc0b --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidatableTypeComparer.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed class ValidatableTypeComparer : IEqualityComparer +{ + public static ValidatableTypeComparer Instance { get; } = new(); + + public bool Equals(ValidatableType? x, ValidatableType? y) + { + if (x is null && y is null) + { + return true; + } + if (x is null || y is null) + { + return false; + } + return SymbolEqualityComparer.Default.Equals(x.Type, y.Type); + } + + public int GetHashCode(ValidatableType? obj) + { + return SymbolEqualityComparer.Default.GetHashCode(obj?.Type); + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidationAttribute.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidationAttribute.cs new file mode 100644 index 000000000000..c29e12a99c0d --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Models/ValidationAttribute.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +internal sealed record class ValidationAttribute( + string Name, + string ClassName, + List Arguments, + Dictionary NamedArguments, + bool IsCustomValidationAttribute +); diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AddValidation.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AddValidation.cs new file mode 100644 index 000000000000..fbf7673b19fd --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AddValidation.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; +using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + internal bool FindAddValidation(SyntaxNode syntaxNode, CancellationToken cancellationToken) + { + if (syntaxNode is InvocationExpressionSyntax + && syntaxNode.TryGetMapMethodName(out var method) + && method == "AddValidation") + { + return true; + } + return false; + } + + internal InterceptableLocation? TransformAddValidation(GeneratorSyntaxContext context, CancellationToken cancellationToken) + { + var node = (InvocationExpressionSyntax)context.Node; + var semanticModel = context.SemanticModel; + return semanticModel.GetInterceptableLocation(node, cancellationToken); + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AttributeParser.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AttributeParser.cs new file mode 100644 index 000000000000..5bea9a8ad218 --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.AttributeParser.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Threading; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + internal static bool ShouldTransformSymbolWithAttribute(SyntaxNode syntaxNode, CancellationToken cancellationToken) + { + return syntaxNode is ClassDeclarationSyntax; + } + + internal ImmutableArray TransformValidatableTypeWithAttribute(GeneratorAttributeSyntaxContext context, CancellationToken cancellationToken) + { + var validatableTypes = new HashSet(ValidatableTypeComparer.Instance); + List visitedTypes = []; + var requiredSymbols = ExtractRequiredSymbols(context.SemanticModel.Compilation, cancellationToken); + if (TryExtractValidatableType((ITypeSymbol)context.TargetSymbol, requiredSymbols, ref validatableTypes, ref visitedTypes)) + { + return [..validatableTypes]; + } + return []; + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.EndpointsParser.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.EndpointsParser.cs new file mode 100644 index 000000000000..da325a562d1f --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.EndpointsParser.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Immutable; +using System.Linq; +using System.Threading; +using Microsoft.AspNetCore.Analyzers.Infrastructure; +using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Operations; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + internal bool FindEndpoints(SyntaxNode syntaxNode, CancellationToken cancellationToken) + { + if (syntaxNode is InvocationExpressionSyntax + && syntaxNode.TryGetMapMethodName(out var method)) + { + return method == "MapMethods" || InvocationOperationExtensions.KnownMethods.Contains(method); + } + return false; + } + + internal IInvocationOperation? TransformEndpoints(GeneratorSyntaxContext context, CancellationToken cancellationToken) + { + if (context.Node is not InvocationExpressionSyntax node) + { + return null; + } + var operation = context.SemanticModel.GetOperation(node, cancellationToken); + AnalyzerDebug.Assert(operation != null, "Operation should not be null."); + return operation is IInvocationOperation invocationOperation + ? invocationOperation + : null; + } + + internal ImmutableArray ExtractValidatableEndpoint((IInvocationOperation? Operation, RequiredSymbols RequiredSymbols) input, CancellationToken cancellationToken) + { + AnalyzerDebug.Assert(input.Operation != null, "Operation should not be null."); + var validatableTypes = ExtractValidatableTypes(input.Operation, input.RequiredSymbols); + return validatableTypes; + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.RequiredSymbolsParser.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.RequiredSymbolsParser.cs new file mode 100644 index 000000000000..f7ba94e097e6 --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.RequiredSymbolsParser.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Threading; +using Microsoft.CodeAnalysis; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + internal RequiredSymbols ExtractRequiredSymbols(Compilation compilation, CancellationToken cancellationToken) + { + return new RequiredSymbols( + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.DisplayAttribute")!, + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.ValidationAttribute")!, + compilation.GetTypeByMetadataName("System.Collections.IEnumerable")!, + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.IValidatableObject")!, + compilation.GetTypeByMetadataName("System.Text.Json.Serialization.JsonDerivedTypeAttribute")!, + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.RequiredAttribute")!, + compilation.GetTypeByMetadataName("System.ComponentModel.DataAnnotations.CustomValidationAttribute")! + ); + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.TypesParser.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.TypesParser.cs new file mode 100644 index 000000000000..0085e53760c6 --- /dev/null +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.TypesParser.cs @@ -0,0 +1,146 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using System.Threading; +using Microsoft.AspNetCore.Analyzers.Infrastructure; +using Microsoft.AspNetCore.Http.RequestDelegateGenerator.StaticRouteHandlerModel; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Operations; + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator; + +public sealed partial class ValidationsGenerator : IIncrementalGenerator +{ + private static readonly SymbolDisplayFormat _symbolDisplayFormat = new( + globalNamespaceStyle: SymbolDisplayGlobalNamespaceStyle.Included, + typeQualificationStyle: SymbolDisplayTypeQualificationStyle.NameAndContainingTypesAndNamespaces); + + internal ImmutableArray ExtractValidatableTypes(IInvocationOperation operation, RequiredSymbols requiredSymbols) + { + AnalyzerDebug.Assert(operation.SemanticModel != null, "SemanticModel should not be null."); + var parameters = operation.TryGetRouteHandlerMethod(operation.SemanticModel, out var method) + ? method.Parameters + : []; + var validatableTypes = new HashSet(ValidatableTypeComparer.Instance); + List visitedTypes = []; + foreach (var parameter in parameters) + { + _ = TryExtractValidatableType(parameter.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes); + } + return [.. validatableTypes]; + } + + internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet validatableTypes, ref List visitedTypes) + { + if (typeSymbol.SpecialType != SpecialType.None) + { + return false; + } + + if (visitedTypes.Contains(typeSymbol)) + { + return true; + } + + visitedTypes.Add(typeSymbol); + + // Extract validatable types discovered in base types of this type and add them to the top-level list. + var current = typeSymbol.BaseType; + while (current != null && current.SpecialType != SpecialType.System_Object) + { + _ = TryExtractValidatableType(current, requiredSymbols, ref validatableTypes, ref visitedTypes); + current = current.BaseType; + } + + // Extract validatable types discovered in members of this type and add them to the top-level list. + var members = ExtractValidatableMembers(typeSymbol, requiredSymbols, ref validatableTypes, ref visitedTypes); + + // Extract the validatable types discovered in the JsonDerivedTypeAttributes of this type and add them to the top-level list. + var derivedTypes = typeSymbol.GetJsonDerivedTypes(requiredSymbols.JsonDerivedTypeAttribute); + foreach (var derivedType in derivedTypes ?? []) + { + _ = TryExtractValidatableType(derivedType, requiredSymbols, ref validatableTypes, ref visitedTypes); + } + + // Add the type itself as a validatable type itself. + validatableTypes.Add(new ValidatableType( + Type: typeSymbol, + Members: members)); + + return true; + } + + internal ImmutableArray ExtractValidatableMembers(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet validatableTypes, ref List visitedTypes) + { + var members = new List(); + foreach (var member in typeSymbol.GetMembers().OfType()) + { + var hasValidatableType = TryExtractValidatableType(member.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes); + var attributes = ExtractValidationAttributes(member, requiredSymbols, out var isRequired); + members.Add(new ValidatableProperty( + ContainingType: member.ContainingType, + Type: member.Type, + Name: member.Name, + DisplayName: member.GetDisplayName(requiredSymbols.DisplayAttribute), + Attributes: attributes)); + } + + return [.. members]; + } + + public ImmutableArray ExtractPropertyTypes(ITypeSymbol type, CancellationToken cancellationToken) + { + var builder = ImmutableArray.CreateBuilder(); + var processed = new HashSet(SymbolEqualityComparer.Default); + + void Traverse(ITypeSymbol currentType) + { + if (currentType == null || currentType.SpecialType != SpecialType.None || processed.Contains(currentType)) + { + return; + } + + processed.Add(currentType); + builder.Add(currentType); + + foreach (var member in currentType.GetMembers().OfType()) + { + if (member.Type is ITypeSymbol propertyType) + { + Traverse(propertyType); + } + } + } + + Traverse(type); + return builder.ToImmutable(); + } + + internal static ImmutableArray ExtractValidationAttributes(ISymbol symbol, RequiredSymbols requiredSymbols, out bool isRequired) + { + var attributes = symbol.GetAttributes(); + if (attributes.Length == 0) + { + isRequired = false; + return []; + } + + // Continue with existing logic... + var validationAttributes = attributes + .Where(attribute => attribute.AttributeClass != null) + .Where(attribute => attribute.AttributeClass!.ImplementsValidationAttribute(requiredSymbols.ValidationAttribute)); + isRequired = validationAttributes.Any(attr => SymbolEqualityComparer.Default.Equals(attr.AttributeClass, requiredSymbols.RequiredAttribute)); + return [.. validationAttributes + .Where(attr => !SymbolEqualityComparer.Default.Equals(attr.AttributeClass, requiredSymbols.ValidationAttribute)) + .Select(attribute => new ValidationAttribute( + Name: symbol.Name + attribute.AttributeClass!.Name, + ClassName: attribute.AttributeClass!.ToDisplayString(_symbolDisplayFormat), + Arguments: [.. attribute.ConstructorArguments.Select(a => a.ToCSharpString())], + NamedArguments: attribute.NamedArguments.ToDictionary(namedArgument => namedArgument.Key, namedArgument => namedArgument.Value.ToCSharpString()), + IsCustomValidationAttribute: SymbolEqualityComparer.Default.Equals(attribute.AttributeClass, requiredSymbols.CustomValidationAttribute)))]; + } +} diff --git a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/ValidationsGenerator.cs b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/ValidationsGenerator.cs index ac7be3762c0d..0af7bc1f8615 100644 --- a/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/ValidationsGenerator.cs +++ b/src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/ValidationsGenerator.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Linq; using Microsoft.CodeAnalysis; namespace Microsoft.AspNetCore.Http.ValidationsGenerator; @@ -9,6 +10,41 @@ public sealed partial class ValidationsGenerator : IIncrementalGenerator { public void Initialize(IncrementalGeneratorInitializationContext context) { - return; + // Resolve the symbols that will be required when making comparisons + // in future steps. + var requiredSymbols = context.CompilationProvider.Select(ExtractRequiredSymbols); + + // Find the builder.Services.AddValidation() call in the application. + var addValidation = context.SyntaxProvider.CreateSyntaxProvider( + predicate: FindAddValidation, + transform: TransformAddValidation + ); + // Extract types that have been marked with [ValidatableType]. + var validatableTypesWithAttribute = context.SyntaxProvider.ForAttributeWithMetadataName( + "Microsoft.AspNetCore.Http.Validation.ValidatableTypeAttribute", + predicate: ShouldTransformSymbolWithAttribute, + transform: TransformValidatableTypeWithAttribute + ); + // Extract all minimal API endpoints in the application. + var endpoints = context.SyntaxProvider + .CreateSyntaxProvider( + predicate: FindEndpoints, + transform: TransformEndpoints) + .Where(endpoint => endpoint is not null); + // Extract validatable types from all endpoints. + var validatableTypesFromEndpoints = endpoints + .Combine(requiredSymbols) + .Select(ExtractValidatableEndpoint); + // Extract all validatable types encountered in the type graph. + var validatableTypes = validatableTypesFromEndpoints + .Concat(validatableTypesWithAttribute) + .Distinct(ValidatableTypeComparer.Instance) + .Collect(); + + var emitInputs = addValidation + .Combine(validatableTypes); + + // Emit ValidatableTypeInfo for all validatable types. + context.RegisterSourceOutput(emitInputs, Emit); } } diff --git a/src/Http/Http.Extensions/test/ValidationEndpointConventionBuilderExtensionsTests.cs b/src/Http/Http.Extensions/test/ValidationEndpointConventionBuilderExtensionsTests.cs new file mode 100644 index 000000000000..c871c407bd31 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationEndpointConventionBuilderExtensionsTests.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Text; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http.Metadata; +using Microsoft.AspNetCore.InternalTesting; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Http.Extensions.Tests; + +public class ValidationEndpointConventionBuilderExtensionsTests : LoggedTest +{ + [Fact] + public async Task DisableValidation_PreventsValidationFilterRegistration() + { + // Arrange + var services = new ServiceCollection(); + services.AddValidation(); + services.AddSingleton(LoggerFactory); + var serviceProvider = services.BuildServiceProvider(); + + var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(serviceProvider)); + + // Act - Create two endpoints - one with validation disabled, one without + var regularBuilder = builder.MapGet("test-enabled", ([Range(5, 10)] int param) => "Validation enabled here."); + var disabledBuilder = builder.MapGet("test-disabled", ([Range(5, 10)] int param) => "Validation disabled here."); + + disabledBuilder.DisableValidation(); + + // Build the endpoints + var dataSource = Assert.Single(builder.DataSources); + var endpoints = dataSource.Endpoints; + + // Assert + Assert.Equal(2, endpoints.Count); + + // Get filter factories from both endpoints + var regularEndpoint = endpoints[0]; + var disabledEndpoint = endpoints[1]; + + // Verify the disabled endpoint has the IDisableValidationMetadata + Assert.Contains(disabledEndpoint.Metadata, m => m is IDisableValidationMetadata); + + // Verify that invalid arguments on the disabled endpoint do not trigger validation + var context = new DefaultHttpContext + { + RequestServices = serviceProvider + }; + context.Request.Method = "GET"; + context.Request.QueryString = new QueryString("?param=15"); + var ms = new MemoryStream(); + context.Response.Body = ms; + + await disabledEndpoint.RequestDelegate(context); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + Assert.Equal("Validation disabled here.", Encoding.UTF8.GetString(ms.ToArray())); + + context = new DefaultHttpContext + { + RequestServices = serviceProvider + }; + context.Request.Method = "GET"; + context.Request.QueryString = new QueryString("?param=15"); + await regularEndpoint.RequestDelegate(context); + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + } + + private class DefaultEndpointRouteBuilder(IApplicationBuilder applicationBuilder) : IEndpointRouteBuilder + { + private IApplicationBuilder ApplicationBuilder { get; } = applicationBuilder ?? throw new ArgumentNullException(nameof(applicationBuilder)); + public IApplicationBuilder CreateApplicationBuilder() => ApplicationBuilder.New(); + public ICollection DataSources { get; } = []; + public IServiceProvider ServiceProvider => ApplicationBuilder.ApplicationServices; + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.ComplexType.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.ComplexType.cs new file mode 100644 index 000000000000..5dd548cf184c --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.ComplexType.cs @@ -0,0 +1,373 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestBase +{ + [Fact] + public async Task CanValidateComplexTypes() + { + // Arrange + var source = """ +using System; +using System.ComponentModel.DataAnnotations; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Validation; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); + +builder.Services.AddValidation(); + +var app = builder.Build(); + +app.MapPost("/complex-type", (ComplexType complexType) => Results.Ok("Passed"!)); + +app.Run(); + +public class ComplexType +{ + [Range(10, 100)] + public int IntegerWithRange { get; set; } = 10; + + [Range(10, 100), Display(Name = "Valid identifier")] + public int IntegerWithRangeAndDisplayName { get; set; } = 50; + + [Required] + public SubType PropertyWithMemberAttributes { get; set; } = new SubType(); + + public SubType PropertyWithoutMemberAttributes { get; set; } = new SubType(); + + public SubTypeWithInheritance PropertyWithInheritance { get; set; } = new SubTypeWithInheritance(); + + public List ListOfSubTypes { get; set; } = []; + + [DerivedValidation(ErrorMessage = "Value must be an even number")] + public int IntegerWithDerivedValidationAttribute { get; set; } + + [CustomValidation(typeof(CustomValidators), nameof(CustomValidators.Validate))] + public int IntegerWithCustomValidation { get; set; } = 0; + + [DerivedValidation, Range(10, 100)] + public int PropertyWithMultipleAttributes { get; set; } = 10; +} + +public class DerivedValidationAttribute : ValidationAttribute +{ + public override bool IsValid(object? value) => value is int number && number % 2 == 0; +} + +public class SubType +{ + [Required] + public string RequiredProperty { get; set; } = "some-value"; + + [StringLength(10)] + public string? StringWithLength { get; set; } +} + +public class SubTypeWithInheritance : SubType +{ + [EmailAddress] + public string? EmailString { get; set; } +} + +public static class CustomValidators +{ + public static ValidationResult Validate(int number, ValidationContext validationContext) + { + var parent = (ComplexType)validationContext.ObjectInstance; + + if (parent.IntegerWithRange == number) + { + return new ValidationResult( + "Can't use the same number value in two properties on the same class.", + new[] { validationContext.MemberName }); + } + + return ValidationResult.Success; + } +} +"""; + await Verify(source, out var compilation); + await VerifyEndpoint(compilation, "/complex-type", async (endpoint, serviceProvider) => + { + await InvalidIntegerWithRangeProducesError(endpoint); + await InvalidIntegerWithRangeAndDisplayNameProducesError(endpoint); + await MissingRequiredSubtypePropertyProducesError(endpoint); + await InvalidRequiredSubtypePropertyProducesError(endpoint); + await InvalidSubTypeWithInheritancePropertyProducesError(endpoint); + await InvalidListOfSubTypesProducesError(endpoint); + await InvalidPropertyWithDerivedValidationAttributeProducesError(endpoint); + await InvalidPropertyWithMultipleAttributesProducesError(endpoint); + await InvalidPropertyWithCustomValidationProducesError(endpoint); + await ValidInputProducesNoWarnings(endpoint); + + async Task InvalidIntegerWithRangeProducesError(Endpoint endpoint) + { + + var payload = """ + { + "IntegerWithRange": 5 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("IntegerWithRange", kvp.Key); + Assert.Equal("The field IntegerWithRange must be between 10 and 100.", kvp.Value.Single()); + }); + } + + async Task InvalidIntegerWithRangeAndDisplayNameProducesError(Endpoint endpoint) + { + var payload = """ + { + "IntegerWithRangeAndDisplayName": 5 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("IntegerWithRangeAndDisplayName", kvp.Key); + Assert.Equal("The field Valid identifier must be between 10 and 100.", kvp.Value.Single()); + }); + } + + async Task MissingRequiredSubtypePropertyProducesError(Endpoint endpoint) + { + var payload = """ + { + "PropertyWithMemberAttributes": null + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("PropertyWithMemberAttributes", kvp.Key); + Assert.Equal("The PropertyWithMemberAttributes field is required.", kvp.Value.Single()); + }); + } + + async Task InvalidRequiredSubtypePropertyProducesError(Endpoint endpoint) + { + var payload = """ + { + "PropertyWithMemberAttributes": { + "RequiredProperty": "", + "StringWithLength": "way-too-long" + } + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, + kvp => + { + Assert.Equal("PropertyWithMemberAttributes.RequiredProperty", kvp.Key); + Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("PropertyWithMemberAttributes.StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }); + } + + async Task InvalidSubTypeWithInheritancePropertyProducesError(Endpoint endpoint) + { + var payload = """ + { + "PropertyWithInheritance": { + "RequiredProperty": "", + "StringWithLength": "way-too-long", + "EmailString": "not-an-email" + } + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, + kvp => + { + Assert.Equal("PropertyWithInheritance.EmailString", kvp.Key); + Assert.Equal("The EmailString field is not a valid e-mail address.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("PropertyWithInheritance.RequiredProperty", kvp.Key); + Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("PropertyWithInheritance.StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }); + } + + async Task InvalidListOfSubTypesProducesError(Endpoint endpoint) + { + var payload = """ + { + "ListOfSubTypes": [ + { + "RequiredProperty": "", + "StringWithLength": "way-too-long" + }, + { + "RequiredProperty": "valid", + "StringWithLength": "way-too-long" + }, + { + "RequiredProperty": "valid", + "StringWithLength": "valid" + } + ] + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, + kvp => + { + Assert.Equal("ListOfSubTypes[0].RequiredProperty", kvp.Key); + Assert.Equal("The RequiredProperty field is required.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("ListOfSubTypes[0].StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }, + kvp => + { + Assert.Equal("ListOfSubTypes[1].StringWithLength", kvp.Key); + Assert.Equal("The field StringWithLength must be a string with a maximum length of 10.", kvp.Value.Single()); + }); + } + + async Task InvalidPropertyWithDerivedValidationAttributeProducesError(Endpoint endpoint) + { + var payload = """ + { + "IntegerWithDerivedValidationAttribute": 5 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("IntegerWithDerivedValidationAttribute", kvp.Key); + Assert.Equal("Value must be an even number", kvp.Value.Single()); + }); + } + + async Task InvalidPropertyWithMultipleAttributesProducesError(Endpoint endpoint) + { + var payload = """ + { + "PropertyWithMultipleAttributes": 5 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("PropertyWithMultipleAttributes", kvp.Key); + Assert.Collection(kvp.Value, + error => + { + Assert.Equal("The field PropertyWithMultipleAttributes is invalid.", error); + }, + error => + { + Assert.Equal("The field PropertyWithMultipleAttributes must be between 10 and 100.", error); + }); + }); + } + + async Task InvalidPropertyWithCustomValidationProducesError(Endpoint endpoint) + { + var payload = """ + { + "IntegerWithRange": 42, + "IntegerWithCustomValidation": 42 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + + await endpoint.RequestDelegate(context); + + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, kvp => + { + Assert.Equal("IntegerWithCustomValidation", kvp.Key); + var error = Assert.Single(kvp.Value); + Assert.Equal("Can't use the same number value in two properties on the same class.", error); + }); + } + + async Task ValidInputProducesNoWarnings(Endpoint endpoint) + { + var payload = """ + { + "IntegerWithRange": 50, + "IntegerWithRangeAndDisplayName": 50, + "PropertyWithMemberAttributes": { + "RequiredProperty": "valid", + "StringWithLength": "valid" + }, + "PropertyWithoutMemberAttributes": { + "RequiredProperty": "valid", + "StringWithLength": "valid" + }, + "PropertyWithInheritance": { + "RequiredProperty": "valid", + "StringWithLength": "valid", + "EmailString": "test@example.com" + }, + "ListOfSubTypes": [], + "IntegerWithDerivedValidationAttribute": 2, + "IntegerWithCustomValidation": 0, + "PropertyWithMultipleAttributes": 12 + } + """; + var context = CreateHttpContextWithPayload(payload, serviceProvider); + await endpoint.RequestDelegate(context); + + Assert.Equal(200, context.Response.StatusCode); + } + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.IValidatableObject.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.IValidatableObject.cs new file mode 100644 index 000000000000..30de31e208b0 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.IValidatableObject.cs @@ -0,0 +1,193 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestBase +{ + [Fact] + public async Task CanValidateIValidatableObject() + { + var source = """ +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Validation; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); +builder.Services.AddSingleton(); +builder.Services.AddValidation(); + +var app = builder.Build(); + +app.MapPost("/validatable-object", (ComplexValidatableType model) => Results.Ok()); + +app.Run(); + +public class ComplexValidatableType: IValidatableObject +{ + [Display(Name = "Value 1")] + public int Value1 { get; set; } + + [EmailAddress] + [Required] + public required string Value2 { get; set; } = "test@example.com"; + + public ValidatableSubType SubType { get; set; } = new ValidatableSubType(); + + public IEnumerable Validate(ValidationContext validationContext) + { + var rangeService = (IRangeService?)validationContext.GetService(typeof(IRangeService)); + var minimum = rangeService?.GetMinimum(); + var maximum = rangeService?.GetMaximum(); + if (Value1 < minimum || Value1 > maximum) + { + yield return new ValidationResult($"The field {nameof(Value1)} must be between {minimum} and {maximum}.", [nameof(Value1)]); + } + } +} + +public class SubType +{ + [Required] + public string RequiredProperty { get; set; } = "some-value"; + + [StringLength(10)] + public string? StringWithLength { get; set; } +} + +public class ValidatableSubType : SubType, IValidatableObject +{ + public string Value3 { get; set; } = "some-value"; + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Value3 != "some-value") + { + yield return new ValidationResult($"The field {validationContext.DisplayName} must be 'some-value'.", [nameof(Value3)]); + } + } +} + +public interface IRangeService +{ + int GetMinimum(); + int GetMaximum(); +} + +public class RangeService : IRangeService +{ + public int GetMinimum() => 10; + public int GetMaximum() => 100; +} +"""; + + await Verify(source, out var compilation); + await VerifyEndpoint(compilation, "/validatable-object", async (endpoint, serviceProvider) => + { + await ValidateMethodCalledIfPropertyValidationsFail(); + await ValidateForSubtypeInvokedFirst(); + await ValidateForTopLevelInvoked(); + + async Task ValidateMethodCalledIfPropertyValidationsFail() + { + var httpContext = CreateHttpContextWithPayload(""" + { + "Value1": 5, + "Value2": "", + "SubType": { + "Value3": "foo", + "RequiredProperty": "", + "StringWithLength": "" + } + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value2", error.Key); + Assert.Collection(error.Value, + msg => Assert.Equal("The Value2 field is required.", msg)); + }, + error => + { + Assert.Equal("SubType.RequiredProperty", error.Key); + Assert.Equal("The RequiredProperty field is required.", error.Value.Single()); + }, + error => + { + Assert.Equal("SubType.Value3", error.Key); + Assert.Equal("The field ValidatableSubType must be 'some-value'.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value1 must be between 10 and 100.", error.Value.Single()); + }); + } + + async Task ValidateForSubtypeInvokedFirst() + { + var httpContext = CreateHttpContextWithPayload(""" + { + "Value1": 5, + "Value2": "test@test.com", + "SubType": { + "Value3": "foo", + "RequiredProperty": "some-value-2", + "StringWithLength": "element" + } + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("SubType.Value3", error.Key); + Assert.Equal("The field ValidatableSubType must be 'some-value'.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value1 must be between 10 and 100.", error.Value.Single()); + }); + } + + async Task ValidateForTopLevelInvoked() + { + var httpContext = CreateHttpContextWithPayload(""" + { + "Value1": 5, + "Value2": "test@test.com", + "SubType": { + "Value3": "some-value", + "RequiredProperty": "some-value-2", + "StringWithLength": "element" + } + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value1 must be between 10 and 100.", error.Value.Single()); + }); + } + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Parameters.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Parameters.cs index 6a6def95f50a..3be6e96fe7a9 100644 --- a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Parameters.cs +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Parameters.cs @@ -11,11 +11,18 @@ public async Task CanValidateParameters() var source = """ using System; using System.ComponentModel.DataAnnotations; +using System.Collections.Generic; +using System.Linq; using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Validation; using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; var builder = WebApplication.CreateBuilder(); +builder.Services.AddValidation(); + var app = builder.Build(); app.MapGet("/params", ( @@ -33,12 +40,45 @@ public class CustomValidationAttribute : ValidationAttribute } """; await Verify(source, out var compilation); - VerifyEndpoint(compilation, "/params", async endpoint => + await VerifyEndpoint(compilation, "/params", async (endpoint, serviceProvider) => { - var context = CreateHttpContext(); + var context = CreateHttpContext(serviceProvider); context.Request.QueryString = new QueryString("?value1=5&value2=5&value3=&value4=3&value5=5"); await endpoint.RequestDelegate(context); - Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + var problemDetails = await AssertBadRequest(context); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("value1", error.Key); + Assert.Equal("The field value1 must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("value2", error.Key); + Assert.Equal("The field Valid identifier must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("value3", error.Key); + Assert.Equal("The value3 field is required.", error.Value.Single()); + }, + error => + { + Assert.Equal("value4", error.Key); + Assert.Equal("Value must be an even number", error.Value.Single()); + }, + error => + { + Assert.Equal("value5", error.Key); + Assert.Collection(error.Value, error => + { + Assert.Equal("The field value5 is invalid.", error); + }, + error => + { + Assert.Equal("The field value5 must be between 10 and 100.", error); + }); + }); }); } } diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Polymorphism.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Polymorphism.cs new file mode 100644 index 000000000000..54148e784a0a --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Polymorphism.cs @@ -0,0 +1,202 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestBase +{ + [Fact] + public async Task CanValidatePolymorphicTypes() + { + var source = """ +using System; +using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.Text.Json.Serialization; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Validation; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); +builder.Services.AddValidation(); + +var app = builder.Build(); + +app.MapPost("/basic-polymorphism", (BaseType model) => Results.Ok()); +app.MapPost("/validatable-polymorphism", (BaseValidatableType model) => Results.Ok()); +app.MapPost("/polymorphism-container", (ContainerType model) => Results.Ok()); + +app.Run(); + +public class ContainerType +{ + public BaseType BaseType { get; set; } = new BaseType(); + public BaseValidatableType BaseValidatableType { get; set; } = new BaseValidatableType(); +} + +[JsonDerivedType(typeof(BaseType), typeDiscriminator: "base")] +[JsonDerivedType(typeof(DerivedType), typeDiscriminator: "derived")] +public class BaseType +{ + [Display(Name = "Value 1")] + [Range(10, 100)] + public int Value1 { get; set; } + + [EmailAddress] + [Required] + public string Value2 { get; set; } = "test@example.com"; +} + +public class DerivedType : BaseType +{ + [Base64String] + public string? Value3 { get; set; } +} + +[JsonDerivedType(typeof(BaseValidatableType), typeDiscriminator: "base")] +[JsonDerivedType(typeof(DerivedValidatableType), typeDiscriminator: "derived")] +public class BaseValidatableType : IValidatableObject +{ + [Display(Name = "Value 1")] + public int Value1 { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (Value1 < 10 || Value1 > 100) + { + yield return new ValidationResult("The field Value 1 must be between 10 and 100.", new[] { nameof(Value1) }); + } + } +} + +public class DerivedValidatableType : BaseValidatableType +{ + [EmailAddress] + public required string Value3 { get; set; } +} +"""; + await Verify(source, out var compilation); + + await VerifyEndpoint(compilation, "/basic-polymorphism", async (endpoint, serviceProvider) => + { + var httpContext = CreateHttpContextWithPayload(""" + { + "$type": "derived", + "Value1": 5, + "Value2": "invalid-email", + "Value3": "invalid-base64" + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value3", error.Key); + Assert.Equal("The Value3 field is not a valid Base64 encoding.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value2", error.Key); + Assert.Equal("The Value2 field is not a valid e-mail address.", error.Value.Single()); + }); + }); + + await VerifyEndpoint(compilation, "/validatable-polymorphism", async (endpoint, serviceProvider) => + { + var httpContext = CreateHttpContextWithPayload(""" + { + "$type": "derived", + "Value1": 5, + "Value3": "invalid-email" + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value3", error.Key); + Assert.Equal("The Value3 field is not a valid e-mail address.", error.Value.Single()); + }, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }); + + httpContext = CreateHttpContextWithPayload(""" + { + "$type": "derived", + "Value1": 5, + "Value3": "test@example.com" + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails1 = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails1.Errors, + error => + { + Assert.Equal("Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }); + }); + + await VerifyEndpoint(compilation, "/polymorphism-container", async (endpoint, serviceProvider) => + { + var httpContext = CreateHttpContextWithPayload(""" + { + "BaseType": { + "$type": "derived", + "Value1": 5, + "Value2": "invalid-email", + "Value3": "invalid-base64" + }, + "BaseValidatableType": { + "$type": "derived", + "Value1": 5, + "Value3": "test@example.com" + } + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("BaseType.Value3", error.Key); + Assert.Equal("The Value3 field is not a valid Base64 encoding.", error.Value.Single()); + }, + error => + { + Assert.Equal("BaseType.Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("BaseType.Value2", error.Key); + Assert.Equal("The Value2 field is not a valid e-mail address.", error.Value.Single()); + }, + error => + { + Assert.Equal("BaseValidatableType.Value1", error.Key); + Assert.Equal("The field Value 1 must be between 10 and 100.", error.Value.Single()); + }); + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Recursion.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Recursion.cs new file mode 100644 index 000000000000..4affa35f8997 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.Recursion.cs @@ -0,0 +1,158 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; + +public partial class ValidationsGeneratorTests : ValidationsGeneratorTestBase +{ + [Fact] + public async Task CanValidateRecursiveTypes() + { + var source = """ +using System.ComponentModel.DataAnnotations; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +var builder = WebApplication.CreateBuilder(); +builder.Services.AddValidation(options => +{ + options.MaxDepth = 8; +}); + +var app = builder.Build(); + +app.MapPost("/recursive-type", (RecursiveType model) => Results.Ok()); + +app.Run(); + +public class RecursiveType +{ + [Range(10, 100)] + public int Value { get; set; } + public RecursiveType? Next { get; set; } +} +"""; + await Verify(source, out var compilation); + + await VerifyEndpoint(compilation, "/recursive-type", async (endpoint, serviceProvider) => + { + await ThrowsExceptionForDeeplyNestedType(endpoint); + await ValidatesTypeWithLimitedNesting(endpoint); + + async Task ThrowsExceptionForDeeplyNestedType(Endpoint endpoint) + { + var httpContext = CreateHttpContextWithPayload(""" + { + "value": 1, + "next": { + "value": 2, + "next": { + "value": 3, + "next": { + "value": 4, + "next": { + "value": 5, + "next": { + "value": 6, + "next": { + "value": 7, + "next": { + "value": 8, + "next": { + "value": 9, + "next": { + "value": 10 + } + } + } + } + } + } + } + } + } + } + """, serviceProvider); + + var exception = await Assert.ThrowsAsync(async () => await endpoint.RequestDelegate(httpContext)); + } + + async Task ValidatesTypeWithLimitedNesting(Endpoint endpoint) + { + var httpContext = CreateHttpContextWithPayload(""" + { + "value": 1, + "next": { + "value": 2, + "next": { + "value": 3, + "next": { + "value": 4, + "next": { + "value": 5, + "next": { + "value": 6, + "next": { + "value": 7, + "next": { + "value": 8 + } + } + } + } + } + } + } + } + """, serviceProvider); + + await endpoint.RequestDelegate(httpContext); + + var problemDetails = await AssertBadRequest(httpContext); + Assert.Collection(problemDetails.Errors, + error => + { + Assert.Equal("Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Next.Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Next.Next.Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Next.Next.Next.Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Next.Next.Next.Next.Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }, + error => + { + Assert.Equal("Next.Next.Next.Next.Next.Next.Next.Value", error.Key); + Assert.Equal("The field Value must be between 10 and 100.", error.Value.Single()); + }); + } + }); + } +} diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTestBase.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTestBase.cs index b8740de3e928..6b3be7472c87 100644 --- a/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTestBase.cs +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGeneratorTestBase.cs @@ -2,12 +2,15 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Globalization; using System.Reflection; using System.Runtime.Loader; using System.Text; +using System.Text.Json; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting.Server; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Http.Validation; using Microsoft.AspNetCore.InternalTesting; using Microsoft.AspNetCore.Routing; using Microsoft.CodeAnalysis; @@ -16,8 +19,8 @@ using Microsoft.CodeAnalysis.Text; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; -using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using static Microsoft.AspNetCore.Http.Generators.Tests.RequestDelegateCreationTestBase; namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; @@ -25,7 +28,7 @@ namespace Microsoft.AspNetCore.Http.ValidationsGenerator.Tests; public class ValidationsGeneratorTestBase : LoggedTestBase { private static readonly CSharpParseOptions ParseOptions = new CSharpParseOptions(LanguageVersion.Preview) - .WithFeatures([new KeyValuePair("InterceptorsNamespaces", "Microsoft.AspNetCore.Http.Validations.Generated")]); + .WithFeatures([new KeyValuePair("InterceptorsNamespaces", "Microsoft.AspNetCore.Http.Validation.Generated")]); internal static Task Verify(string source, out Compilation compilation) { @@ -50,9 +53,11 @@ internal static Task Verify(string source, out Compilation compilation) MetadataReference.CreateFromFile(typeof(ValidateOptionsResult).Assembly.Location), MetadataReference.CreateFromFile(typeof(IHttpMethodMetadata).Assembly.Location), MetadataReference.CreateFromFile(typeof(IResult).Assembly.Location), - MetadataReference.CreateFromFile(typeof(HttpJsonServiceExtensions).Assembly.Location) + MetadataReference.CreateFromFile(typeof(HttpJsonServiceExtensions).Assembly.Location), + MetadataReference.CreateFromFile(typeof(IValidatableInfoResolver).Assembly.Location), + MetadataReference.CreateFromFile(typeof(EndpointFilterFactoryContext).Assembly.Location), ]); - var inputCompilation = CSharpCompilation.Create("OpenApiXmlCommentGeneratorSample", + var inputCompilation = CSharpCompilation.Create("ValidationsGeneratorSample", [CSharpSyntaxTree.ParseText(source, options: ParseOptions, path: "Program.cs")], references, new CSharpCompilationOptions(OutputKind.ConsoleApplication)); @@ -60,13 +65,52 @@ internal static Task Verify(string source, out Compilation compilation) var driver = CSharpGeneratorDriver.Create(generators: [generator.AsSourceGenerator()], parseOptions: ParseOptions); return Verifier .Verify(driver.RunGeneratorsAndUpdateCompilation(inputCompilation, out compilation, out var diagnostics)) + .AutoVerify() .UseDirectory(SkipOnHelixAttribute.OnHelix() ? Path.Combine(Environment.GetEnvironmentVariable("HELIX_WORKITEM_ROOT"), "ValidationsGenerator", "snapshots") : "snapshots"); } - internal static void VerifyEndpoint(Compilation compilation, string routePattern, Action verifyFunc) + internal static void VerifyValidatableType(Compilation compilation, string typeName, Action verifyFunc) { + if (TryResolveServicesFromCompilation(compilation, targetAssemblyName: "Microsoft.AspNetCore.Http.Abstractions", typeName: "Microsoft.AspNetCore.Http.Validation.ValidationOptions", out var services, out var serviceType, out var outputAssemblyName) is false) + { + throw new InvalidOperationException("Could not resolve services from compilation."); + } + var targetAssembly = AppDomain.CurrentDomain.GetAssemblies().FirstOrDefault(assembly => assembly.GetName().Name == outputAssemblyName); + var type = targetAssembly.GetType(typeName, throwOnError: false); + + // Get IOptions first + var optionsType = typeof(IOptions<>).MakeGenericType(serviceType); + var optionsInstance = services.GetService(optionsType) ?? throw new InvalidOperationException("Could not resolve IOptions."); + + // Then access the Value property + var valueProperty = optionsType.GetProperty("Value"); + var service = (ValidationOptions)valueProperty.GetValue(optionsInstance) ?? throw new InvalidOperationException("Could not resolve ValidationOptions."); + if (service.TryGetValidatableTypeInfo(type, out var validatableTypeInfo) is false) + { + throw new InvalidOperationException("Could not resolve ValidatableTypeInfo."); + } + verifyFunc(validatableTypeInfo); + } + + internal static async Task VerifyEndpoint(Compilation compilation, string routePattern, Func verifyFunc) + { + if (TryResolveServicesFromCompilation(compilation, targetAssemblyName: "Microsoft.AspNetCore.Routing", typeName: "Microsoft.AspNetCore.Routing.EndpointDataSource", out var services, out var serviceType, out var outputAssemblyName) is false) + { + throw new InvalidOperationException("Could not resolve services from compilation."); + } + var service = services.GetService(serviceType) ?? throw new InvalidOperationException("Could not resolve EndpointDataSource."); + var endpoints = (IReadOnlyList)service.GetType().GetProperty("Endpoints", BindingFlags.Instance | BindingFlags.Public).GetValue(service); + var endpoint = endpoints.FirstOrDefault(endpoint => endpoint is RouteEndpoint routeEndpoint && routeEndpoint.RoutePattern.RawText == routePattern); + await verifyFunc(endpoint, services); + } + + private static bool TryResolveServicesFromCompilation(Compilation compilation, string targetAssemblyName, string typeName, out IServiceProvider serviceProvider, out Type serviceType, out string outputAssemblyName) + { + serviceProvider = null; + serviceType = null; + outputAssemblyName = $"TestProject-{Guid.NewGuid()}"; var assemblyName = compilation.AssemblyName; var symbolsName = Path.ChangeExtension(assemblyName, "pdb"); @@ -76,7 +120,7 @@ internal static void VerifyEndpoint(Compilation compilation, string routePattern var emitOptions = new EmitOptions( debugInformationFormat: DebugInformationFormat.PortablePdb, pdbFilePath: symbolsName, - outputNameOverride: $"TestProject-{Guid.NewGuid()}"); + outputNameOverride: outputAssemblyName); var embeddedTexts = new List(); @@ -135,28 +179,24 @@ void OnEntryPointExit(Exception exception) if (factory == null) { - return; + return false; } var services = ((IHost)factory([$"--{HostDefaults.ApplicationKey}={assemblyName}"])).Services; var applicationLifetime = services.GetRequiredService(); - using (var registration = applicationLifetime.ApplicationStarted.Register(() => waitForStartTcs.TrySetResult(0))) - { - waitForStartTcs.Task.Wait(); - var targetAssembly = AppDomain.CurrentDomain.GetAssemblies().FirstOrDefault(assembly => assembly.GetName().Name == "Microsoft.AspNetCore.Routing"); - var serviceType = targetAssembly.GetType("Microsoft.AspNetCore.Routing.EndpointDataSource", throwOnError: false); - - if (serviceType == null) - { - return; - } + using var registration = applicationLifetime.ApplicationStarted.Register(() => waitForStartTcs.TrySetResult(0)); + waitForStartTcs.Task.Wait(); + var targetAssembly = AppDomain.CurrentDomain.GetAssemblies().FirstOrDefault(assembly => assembly.GetName().Name == targetAssemblyName); + serviceType = targetAssembly.GetType(typeName, throwOnError: false); - var service = services.GetService(serviceType) ?? throw new InvalidOperationException("Could not resolve EndpointDataSource."); - var endpoints = (IReadOnlyList)serviceType.GetProperty("Endpoints", BindingFlags.Instance | BindingFlags.Public).GetValue(service); - var endpoint = endpoints.FirstOrDefault(endpoint => endpoint is RouteEndpoint routeEndpoint && routeEndpoint.RoutePattern.RawText == routePattern); - verifyFunc(endpoint); + if (serviceType == null) + { + return false; } + + serviceProvider = services; + return true; } private sealed class NoopHostLifetime : IHostLifetime @@ -510,10 +550,10 @@ private sealed class HostAbortedException : Exception } } - internal HttpContext CreateHttpContext(IServiceProvider serviceProvider = null) + internal HttpContext CreateHttpContext(IServiceProvider serviceProvider) { var httpContext = new DefaultHttpContext(); - httpContext.RequestServices = serviceProvider ?? CreateServiceProvider(); + httpContext.RequestServices = serviceProvider; var outStream = new MemoryStream(); httpContext.Response.Body = outStream; @@ -521,14 +561,24 @@ internal HttpContext CreateHttpContext(IServiceProvider serviceProvider = null) return httpContext; } - internal ServiceProvider CreateServiceProvider(Action configureServices = null) + internal HttpContext CreateHttpContextWithPayload(string requestData, IServiceProvider serviceProvider = null) { - var serviceCollection = new ServiceCollection(); - serviceCollection.AddSingleton(LoggerFactory); - if (configureServices is not null) - { - configureServices(serviceCollection); - } - return serviceCollection.BuildServiceProvider(); + var httpContext = CreateHttpContext(serviceProvider); + httpContext.Features.Set(new RequestBodyDetectionFeature(true)); + httpContext.Request.Headers["Content-Type"] = "application/json"; + + var stream = new MemoryStream(System.Text.Encoding.UTF8.GetBytes(requestData)); + httpContext.Request.Body = stream; + httpContext.Request.Headers["Content-Length"] = stream.Length.ToString(CultureInfo.InvariantCulture); + return httpContext; + } + + internal async Task AssertBadRequest(HttpContext context) + { + Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode); + context.Response.Body.Position = 0; + using var reader = new StreamReader(context.Response.Body); + var responseBody = await reader.ReadToEndAsync(); + return JsonSerializer.Deserialize(responseBody); } } diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateComplexTypes#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateComplexTypes#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..5ca75e48b6e9 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateComplexTypes#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,413 @@ +//HintName: ValidatableInfoResolver.g.cs +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + private readonly global::System.ComponentModel.DataAnnotations.ValidationAttribute[] _validationAttributes; + + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName, + global::System.ComponentModel.DataAnnotations.ValidationAttribute[] validationAttributes) : base(containingType, propertyType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo validatableInfo) + { + validatableInfo = null; + if (type == typeof(global::SubType)) + { + validatableInfo = CreateSubType(); + return true; + } + if (type == typeof(global::SubTypeWithInheritance)) + { + validatableInfo = CreateSubTypeWithInheritance(); + return true; + } + if (type == typeof(global::ComplexType)) + { + validatableInfo = CreateComplexType(); + return true; + } + + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo validatableInfo) + { + validatableInfo = null; + return false; + } + + private ValidatableTypeInfo CreateSubType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::SubType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubType), + propertyType: typeof(string), + name: "RequiredProperty", + displayName: "RequiredProperty", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute), [], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RequiredAttribute")] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubType), + propertyType: typeof(string), + name: "StringWithLength", + displayName: "StringWithLength", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.StringLengthAttribute), [10], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.StringLengthAttribute")] + ), + ] + ); + } + private ValidatableTypeInfo CreateSubTypeWithInheritance() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::SubTypeWithInheritance), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubTypeWithInheritance), + propertyType: typeof(string), + name: "EmailString", + displayName: "EmailString", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.EmailAddressAttribute), [], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.EmailAddressAttribute")] + ), + ] + ); + } + private ValidatableTypeInfo CreateComplexType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::ComplexType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithRange", + displayName: "IntegerWithRange", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute), [10, 100], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RangeAttribute")] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithRangeAndDisplayName", + displayName: "Valid identifier", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute), [10, 100], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RangeAttribute")] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::SubType), + name: "PropertyWithMemberAttributes", + displayName: "PropertyWithMemberAttributes", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute), [], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RequiredAttribute")] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::SubType), + name: "PropertyWithoutMemberAttributes", + displayName: "PropertyWithoutMemberAttributes", + validationAttributes: [] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::SubTypeWithInheritance), + name: "PropertyWithInheritance", + displayName: "PropertyWithInheritance", + validationAttributes: [] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::System.Collections.Generic.List), + name: "ListOfSubTypes", + displayName: "ListOfSubTypes", + validationAttributes: [] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithDerivedValidationAttribute", + displayName: "IntegerWithDerivedValidationAttribute", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::DerivedValidationAttribute), [], new global::System.Collections.Generic.Dictionary { { "ErrorMessage", "Value must be an even number" } }) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::DerivedValidationAttribute")] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithCustomValidation", + displayName: "IntegerWithCustomValidation", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.CustomValidationAttribute), [typeof(CustomValidators), "Validate"], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.CustomValidationAttribute")] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "PropertyWithMultipleAttributes", + displayName: "PropertyWithMultipleAttributes", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::DerivedValidationAttribute), [], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::DerivedValidationAttribute"), ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute), [10, 100], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RangeAttribute")] + ), + ] + ); + } + + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "I71YCOnkIuFyp29JNyKEXIEBAABQcm9ncmFtLmNz")] + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type AttributeType, object[] Arguments, global::System.Collections.Generic.Dictionary NamedArguments); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute? GetOrCreateValidationAttribute( + global::System.Type attributeType, + object[] arguments, + global::System.Collections.Generic.Dictionary namedArguments) + { + var key = new CacheKey(attributeType, arguments, namedArguments); + return _cache.GetOrAdd(key, static k => + { + var type = k.AttributeType; + var args = k.Arguments; + + global::System.ComponentModel.DataAnnotations.ValidationAttribute attribute; + + if (args.Length == 0) + { + attribute = type switch + { + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute) => new global::System.ComponentModel.DataAnnotations.RequiredAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.EmailAddressAttribute) => new global::System.ComponentModel.DataAnnotations.EmailAddressAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.PhoneAttribute) => new global::System.ComponentModel.DataAnnotations.PhoneAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.UrlAttribute) => new global::System.ComponentModel.DataAnnotations.UrlAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.CreditCardAttribute) => new global::System.ComponentModel.DataAnnotations.CreditCardAttribute(), + _ when typeof(global::System.ComponentModel.DataAnnotations.ValidationAttribute).IsAssignableFrom(type) => + (global::System.ComponentModel.DataAnnotations.ValidationAttribute)global::System.Activator.CreateInstance(type)! + }; + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.CustomValidationAttribute) && args.Length == 2) + { + // CustomValidationAttribute requires special handling + // First argument is a type, second is a method name + if (args[0] is global::System.Type validatingType && args[1] is string methodName) + { + attribute = new global::System.ComponentModel.DataAnnotations.CustomValidationAttribute(validatingType, methodName); + } + else + { + throw new global::System.ArgumentException($"Invalid arguments for CustomValidationAttribute: Type and method name required"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.StringLengthAttribute)) + { + if (args[0] is int maxLength) + { + attribute = new global::System.ComponentModel.DataAnnotations.StringLengthAttribute(maxLength); + } + else + { + throw new global::System.ArgumentException($"Invalid maxLength value for StringLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.MinLengthAttribute)) + { + if (args[0] is int length) + { + attribute = new global::System.ComponentModel.DataAnnotations.MinLengthAttribute(length); + } + else + { + throw new global::System.ArgumentException($"Invalid length value for MinLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.MaxLengthAttribute)) + { + if (args[0] is int length) + { + attribute = new global::System.ComponentModel.DataAnnotations.MaxLengthAttribute(length); + } + else + { + throw new global::System.ArgumentException($"Invalid length value for MaxLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute) && args.Length == 2) + { + if (args[0] is int min && args[1] is int max) + { + attribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(min, max); + } + else if (args[0] is double dmin && args[1] is double dmax) + { + attribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(dmin, dmax); + } + else + { + throw new global::System.ArgumentException($"Invalid range values for RangeAttribute: {args[0]}, {args[1]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute)) + { + if (args[0] is string pattern) + { + attribute = new global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute(pattern); + } + else + { + throw new global::System.ArgumentException($"Invalid pattern for RegularExpressionAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.CompareAttribute)) + { + if (args[0] is string otherProperty) + { + attribute = new global::System.ComponentModel.DataAnnotations.CompareAttribute(otherProperty); + } + else + { + throw new global::System.ArgumentException($"Invalid otherProperty for CompareAttribute: {args[0]}"); + } + } + else if (typeof(global::System.ComponentModel.DataAnnotations.ValidationAttribute).IsAssignableFrom(type)) + { + var constructors = type.GetConstructors(); + var success = false; + attribute = null!; + + foreach (var constructor in constructors) + { + var parameters = constructor.GetParameters(); + if (parameters.Length != args.Length) + continue; + + var convertedArgs = new object[args.Length]; + var canUseConstructor = true; + + for (var i = 0; i < parameters.Length; i++) + { + try + { + if (args[i] != null && args[i].GetType() == parameters[i].ParameterType) + { + // Type already matches, use as-is + convertedArgs[i] = args[i]; + } + else + { + // Try to convert + convertedArgs[i] = global::System.Convert.ChangeType(args[i], parameters[i].ParameterType); + } + } + catch + { + canUseConstructor = false; + break; + } + } + + if (canUseConstructor) + { + attribute = (global::System.ComponentModel.DataAnnotations.ValidationAttribute)global::System.Activator.CreateInstance(type, convertedArgs)!; + success = true; + break; + } + } + + if (!success) + { + throw new global::System.ArgumentException($"Could not find a suitable constructor for validation attribute type: {type.FullName}"); + } + } + else + { + throw new global::System.ArgumentException($"Unsupported validation attribute type: {type.FullName}"); + } + + // Apply named arguments after construction + foreach (var namedArg in k.NamedArguments) + { + var prop = type.GetProperty(namedArg.Key); + if (prop != null && prop.CanWrite) + { + try + { + if (namedArg.Value != null && namedArg.Value.GetType() == prop.PropertyType) + { + // Type already matches, use as-is + prop.SetValue(attribute, namedArg.Value); + } + else + { + // Try to convert + prop.SetValue(attribute, global::System.Convert.ChangeType(namedArg.Value, prop.PropertyType)); + } + } + catch (global::System.Exception ex) + { + throw new global::System.ArgumentException($"Failed to set property {namedArg.Key} on {type.FullName}: {ex.Message}"); + } + } + } + + return attribute; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateIValidatableObject#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateIValidatableObject#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..231483eb939e --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateIValidatableObject#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,371 @@ +//HintName: ValidatableInfoResolver.g.cs +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + private readonly global::System.ComponentModel.DataAnnotations.ValidationAttribute[] _validationAttributes; + + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName, + global::System.ComponentModel.DataAnnotations.ValidationAttribute[] validationAttributes) : base(containingType, propertyType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo validatableInfo) + { + validatableInfo = null; + if (type == typeof(global::SubType)) + { + validatableInfo = CreateSubType(); + return true; + } + if (type == typeof(global::ValidatableSubType)) + { + validatableInfo = CreateValidatableSubType(); + return true; + } + if (type == typeof(global::ComplexValidatableType)) + { + validatableInfo = CreateComplexValidatableType(); + return true; + } + + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo validatableInfo) + { + validatableInfo = null; + return false; + } + + private ValidatableTypeInfo CreateSubType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::SubType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubType), + propertyType: typeof(string), + name: "RequiredProperty", + displayName: "RequiredProperty", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute), [], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RequiredAttribute")] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubType), + propertyType: typeof(string), + name: "StringWithLength", + displayName: "StringWithLength", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.StringLengthAttribute), [10], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.StringLengthAttribute")] + ), + ] + ); + } + private ValidatableTypeInfo CreateValidatableSubType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::ValidatableSubType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ValidatableSubType), + propertyType: typeof(string), + name: "Value3", + displayName: "Value3", + validationAttributes: [] + ), + ] + ); + } + private ValidatableTypeInfo CreateComplexValidatableType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::ComplexValidatableType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexValidatableType), + propertyType: typeof(int), + name: "Value1", + displayName: "Value 1", + validationAttributes: [] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexValidatableType), + propertyType: typeof(string), + name: "Value2", + displayName: "Value2", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.EmailAddressAttribute), [], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.EmailAddressAttribute"), ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute), [], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RequiredAttribute")] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexValidatableType), + propertyType: typeof(global::ValidatableSubType), + name: "SubType", + displayName: "SubType", + validationAttributes: [] + ), + ] + ); + } + + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "uQnZr9MSHY6ZEeLkq015qrABAABQcm9ncmFtLmNz")] + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type AttributeType, object[] Arguments, global::System.Collections.Generic.Dictionary NamedArguments); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute? GetOrCreateValidationAttribute( + global::System.Type attributeType, + object[] arguments, + global::System.Collections.Generic.Dictionary namedArguments) + { + var key = new CacheKey(attributeType, arguments, namedArguments); + return _cache.GetOrAdd(key, static k => + { + var type = k.AttributeType; + var args = k.Arguments; + + global::System.ComponentModel.DataAnnotations.ValidationAttribute attribute; + + if (args.Length == 0) + { + attribute = type switch + { + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute) => new global::System.ComponentModel.DataAnnotations.RequiredAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.EmailAddressAttribute) => new global::System.ComponentModel.DataAnnotations.EmailAddressAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.PhoneAttribute) => new global::System.ComponentModel.DataAnnotations.PhoneAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.UrlAttribute) => new global::System.ComponentModel.DataAnnotations.UrlAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.CreditCardAttribute) => new global::System.ComponentModel.DataAnnotations.CreditCardAttribute(), + _ when typeof(global::System.ComponentModel.DataAnnotations.ValidationAttribute).IsAssignableFrom(type) => + (global::System.ComponentModel.DataAnnotations.ValidationAttribute)global::System.Activator.CreateInstance(type)! + }; + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.CustomValidationAttribute) && args.Length == 2) + { + // CustomValidationAttribute requires special handling + // First argument is a type, second is a method name + if (args[0] is global::System.Type validatingType && args[1] is string methodName) + { + attribute = new global::System.ComponentModel.DataAnnotations.CustomValidationAttribute(validatingType, methodName); + } + else + { + throw new global::System.ArgumentException($"Invalid arguments for CustomValidationAttribute: Type and method name required"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.StringLengthAttribute)) + { + if (args[0] is int maxLength) + { + attribute = new global::System.ComponentModel.DataAnnotations.StringLengthAttribute(maxLength); + } + else + { + throw new global::System.ArgumentException($"Invalid maxLength value for StringLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.MinLengthAttribute)) + { + if (args[0] is int length) + { + attribute = new global::System.ComponentModel.DataAnnotations.MinLengthAttribute(length); + } + else + { + throw new global::System.ArgumentException($"Invalid length value for MinLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.MaxLengthAttribute)) + { + if (args[0] is int length) + { + attribute = new global::System.ComponentModel.DataAnnotations.MaxLengthAttribute(length); + } + else + { + throw new global::System.ArgumentException($"Invalid length value for MaxLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute) && args.Length == 2) + { + if (args[0] is int min && args[1] is int max) + { + attribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(min, max); + } + else if (args[0] is double dmin && args[1] is double dmax) + { + attribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(dmin, dmax); + } + else + { + throw new global::System.ArgumentException($"Invalid range values for RangeAttribute: {args[0]}, {args[1]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute)) + { + if (args[0] is string pattern) + { + attribute = new global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute(pattern); + } + else + { + throw new global::System.ArgumentException($"Invalid pattern for RegularExpressionAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.CompareAttribute)) + { + if (args[0] is string otherProperty) + { + attribute = new global::System.ComponentModel.DataAnnotations.CompareAttribute(otherProperty); + } + else + { + throw new global::System.ArgumentException($"Invalid otherProperty for CompareAttribute: {args[0]}"); + } + } + else if (typeof(global::System.ComponentModel.DataAnnotations.ValidationAttribute).IsAssignableFrom(type)) + { + var constructors = type.GetConstructors(); + var success = false; + attribute = null!; + + foreach (var constructor in constructors) + { + var parameters = constructor.GetParameters(); + if (parameters.Length != args.Length) + continue; + + var convertedArgs = new object[args.Length]; + var canUseConstructor = true; + + for (var i = 0; i < parameters.Length; i++) + { + try + { + if (args[i] != null && args[i].GetType() == parameters[i].ParameterType) + { + // Type already matches, use as-is + convertedArgs[i] = args[i]; + } + else + { + // Try to convert + convertedArgs[i] = global::System.Convert.ChangeType(args[i], parameters[i].ParameterType); + } + } + catch + { + canUseConstructor = false; + break; + } + } + + if (canUseConstructor) + { + attribute = (global::System.ComponentModel.DataAnnotations.ValidationAttribute)global::System.Activator.CreateInstance(type, convertedArgs)!; + success = true; + break; + } + } + + if (!success) + { + throw new global::System.ArgumentException($"Could not find a suitable constructor for validation attribute type: {type.FullName}"); + } + } + else + { + throw new global::System.ArgumentException($"Unsupported validation attribute type: {type.FullName}"); + } + + // Apply named arguments after construction + foreach (var namedArg in k.NamedArguments) + { + var prop = type.GetProperty(namedArg.Key); + if (prop != null && prop.CanWrite) + { + try + { + if (namedArg.Value != null && namedArg.Value.GetType() == prop.PropertyType) + { + // Type already matches, use as-is + prop.SetValue(attribute, namedArg.Value); + } + else + { + // Try to convert + prop.SetValue(attribute, global::System.Convert.ChangeType(namedArg.Value, prop.PropertyType)); + } + } + catch (global::System.Exception ex) + { + throw new global::System.ArgumentException($"Failed to set property {namedArg.Key} on {type.FullName}: {ex.Message}"); + } + } + } + + return attribute; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateParameters#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateParameters#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..2e43723688ea --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateParameters#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,290 @@ +//HintName: ValidatableInfoResolver.g.cs +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + private readonly global::System.ComponentModel.DataAnnotations.ValidationAttribute[] _validationAttributes; + + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName, + global::System.ComponentModel.DataAnnotations.ValidationAttribute[] validationAttributes) : base(containingType, propertyType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo validatableInfo) + { + validatableInfo = null; + + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo validatableInfo) + { + validatableInfo = null; + return false; + } + + + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "85TK7bWNSSMP7r/P9i3t43YBAABQcm9ncmFtLmNz")] + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type AttributeType, object[] Arguments, global::System.Collections.Generic.Dictionary NamedArguments); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute? GetOrCreateValidationAttribute( + global::System.Type attributeType, + object[] arguments, + global::System.Collections.Generic.Dictionary namedArguments) + { + var key = new CacheKey(attributeType, arguments, namedArguments); + return _cache.GetOrAdd(key, static k => + { + var type = k.AttributeType; + var args = k.Arguments; + + global::System.ComponentModel.DataAnnotations.ValidationAttribute attribute; + + if (args.Length == 0) + { + attribute = type switch + { + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute) => new global::System.ComponentModel.DataAnnotations.RequiredAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.EmailAddressAttribute) => new global::System.ComponentModel.DataAnnotations.EmailAddressAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.PhoneAttribute) => new global::System.ComponentModel.DataAnnotations.PhoneAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.UrlAttribute) => new global::System.ComponentModel.DataAnnotations.UrlAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.CreditCardAttribute) => new global::System.ComponentModel.DataAnnotations.CreditCardAttribute(), + _ when typeof(global::System.ComponentModel.DataAnnotations.ValidationAttribute).IsAssignableFrom(type) => + (global::System.ComponentModel.DataAnnotations.ValidationAttribute)global::System.Activator.CreateInstance(type)! + }; + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.CustomValidationAttribute) && args.Length == 2) + { + // CustomValidationAttribute requires special handling + // First argument is a type, second is a method name + if (args[0] is global::System.Type validatingType && args[1] is string methodName) + { + attribute = new global::System.ComponentModel.DataAnnotations.CustomValidationAttribute(validatingType, methodName); + } + else + { + throw new global::System.ArgumentException($"Invalid arguments for CustomValidationAttribute: Type and method name required"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.StringLengthAttribute)) + { + if (args[0] is int maxLength) + { + attribute = new global::System.ComponentModel.DataAnnotations.StringLengthAttribute(maxLength); + } + else + { + throw new global::System.ArgumentException($"Invalid maxLength value for StringLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.MinLengthAttribute)) + { + if (args[0] is int length) + { + attribute = new global::System.ComponentModel.DataAnnotations.MinLengthAttribute(length); + } + else + { + throw new global::System.ArgumentException($"Invalid length value for MinLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.MaxLengthAttribute)) + { + if (args[0] is int length) + { + attribute = new global::System.ComponentModel.DataAnnotations.MaxLengthAttribute(length); + } + else + { + throw new global::System.ArgumentException($"Invalid length value for MaxLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute) && args.Length == 2) + { + if (args[0] is int min && args[1] is int max) + { + attribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(min, max); + } + else if (args[0] is double dmin && args[1] is double dmax) + { + attribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(dmin, dmax); + } + else + { + throw new global::System.ArgumentException($"Invalid range values for RangeAttribute: {args[0]}, {args[1]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute)) + { + if (args[0] is string pattern) + { + attribute = new global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute(pattern); + } + else + { + throw new global::System.ArgumentException($"Invalid pattern for RegularExpressionAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.CompareAttribute)) + { + if (args[0] is string otherProperty) + { + attribute = new global::System.ComponentModel.DataAnnotations.CompareAttribute(otherProperty); + } + else + { + throw new global::System.ArgumentException($"Invalid otherProperty for CompareAttribute: {args[0]}"); + } + } + else if (typeof(global::System.ComponentModel.DataAnnotations.ValidationAttribute).IsAssignableFrom(type)) + { + var constructors = type.GetConstructors(); + var success = false; + attribute = null!; + + foreach (var constructor in constructors) + { + var parameters = constructor.GetParameters(); + if (parameters.Length != args.Length) + continue; + + var convertedArgs = new object[args.Length]; + var canUseConstructor = true; + + for (var i = 0; i < parameters.Length; i++) + { + try + { + if (args[i] != null && args[i].GetType() == parameters[i].ParameterType) + { + // Type already matches, use as-is + convertedArgs[i] = args[i]; + } + else + { + // Try to convert + convertedArgs[i] = global::System.Convert.ChangeType(args[i], parameters[i].ParameterType); + } + } + catch + { + canUseConstructor = false; + break; + } + } + + if (canUseConstructor) + { + attribute = (global::System.ComponentModel.DataAnnotations.ValidationAttribute)global::System.Activator.CreateInstance(type, convertedArgs)!; + success = true; + break; + } + } + + if (!success) + { + throw new global::System.ArgumentException($"Could not find a suitable constructor for validation attribute type: {type.FullName}"); + } + } + else + { + throw new global::System.ArgumentException($"Unsupported validation attribute type: {type.FullName}"); + } + + // Apply named arguments after construction + foreach (var namedArg in k.NamedArguments) + { + var prop = type.GetProperty(namedArg.Key); + if (prop != null && prop.CanWrite) + { + try + { + if (namedArg.Value != null && namedArg.Value.GetType() == prop.PropertyType) + { + // Type already matches, use as-is + prop.SetValue(attribute, namedArg.Value); + } + else + { + // Try to convert + prop.SetValue(attribute, global::System.Convert.ChangeType(namedArg.Value, prop.PropertyType)); + } + } + catch (global::System.Exception ex) + { + throw new global::System.ArgumentException($"Failed to set property {namedArg.Key} on {type.FullName}: {ex.Message}"); + } + } + } + + return attribute; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidatePolymorphicTypes#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidatePolymorphicTypes#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..52ef1a3a71d7 --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidatePolymorphicTypes#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,404 @@ +//HintName: ValidatableInfoResolver.g.cs +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + private readonly global::System.ComponentModel.DataAnnotations.ValidationAttribute[] _validationAttributes; + + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName, + global::System.ComponentModel.DataAnnotations.ValidationAttribute[] validationAttributes) : base(containingType, propertyType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo validatableInfo) + { + validatableInfo = null; + if (type == typeof(global::DerivedType)) + { + validatableInfo = CreateDerivedType(); + return true; + } + if (type == typeof(global::BaseType)) + { + validatableInfo = CreateBaseType(); + return true; + } + if (type == typeof(global::DerivedValidatableType)) + { + validatableInfo = CreateDerivedValidatableType(); + return true; + } + if (type == typeof(global::BaseValidatableType)) + { + validatableInfo = CreateBaseValidatableType(); + return true; + } + if (type == typeof(global::ContainerType)) + { + validatableInfo = CreateContainerType(); + return true; + } + + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo validatableInfo) + { + validatableInfo = null; + return false; + } + + private ValidatableTypeInfo CreateDerivedType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::DerivedType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::DerivedType), + propertyType: typeof(string), + name: "Value3", + displayName: "Value3", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.Base64StringAttribute), [], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.Base64StringAttribute")] + ), + ] + ); + } + private ValidatableTypeInfo CreateBaseType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::BaseType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::BaseType), + propertyType: typeof(int), + name: "Value1", + displayName: "Value 1", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute), [10, 100], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RangeAttribute")] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::BaseType), + propertyType: typeof(string), + name: "Value2", + displayName: "Value2", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.EmailAddressAttribute), [], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.EmailAddressAttribute"), ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute), [], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RequiredAttribute")] + ), + ] + ); + } + private ValidatableTypeInfo CreateDerivedValidatableType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::DerivedValidatableType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::DerivedValidatableType), + propertyType: typeof(string), + name: "Value3", + displayName: "Value3", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.EmailAddressAttribute), [], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.EmailAddressAttribute")] + ), + ] + ); + } + private ValidatableTypeInfo CreateBaseValidatableType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::BaseValidatableType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::BaseValidatableType), + propertyType: typeof(int), + name: "Value1", + displayName: "Value 1", + validationAttributes: [] + ), + ] + ); + } + private ValidatableTypeInfo CreateContainerType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::ContainerType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ContainerType), + propertyType: typeof(global::BaseType), + name: "BaseType", + displayName: "BaseType", + validationAttributes: [] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ContainerType), + propertyType: typeof(global::BaseValidatableType), + name: "BaseValidatableType", + displayName: "BaseValidatableType", + validationAttributes: [] + ), + ] + ); + } + + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "2dHhUDyiknXJLaQw/hXAKIgBAABQcm9ncmFtLmNz")] + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type AttributeType, object[] Arguments, global::System.Collections.Generic.Dictionary NamedArguments); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute? GetOrCreateValidationAttribute( + global::System.Type attributeType, + object[] arguments, + global::System.Collections.Generic.Dictionary namedArguments) + { + var key = new CacheKey(attributeType, arguments, namedArguments); + return _cache.GetOrAdd(key, static k => + { + var type = k.AttributeType; + var args = k.Arguments; + + global::System.ComponentModel.DataAnnotations.ValidationAttribute attribute; + + if (args.Length == 0) + { + attribute = type switch + { + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute) => new global::System.ComponentModel.DataAnnotations.RequiredAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.EmailAddressAttribute) => new global::System.ComponentModel.DataAnnotations.EmailAddressAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.PhoneAttribute) => new global::System.ComponentModel.DataAnnotations.PhoneAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.UrlAttribute) => new global::System.ComponentModel.DataAnnotations.UrlAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.CreditCardAttribute) => new global::System.ComponentModel.DataAnnotations.CreditCardAttribute(), + _ when typeof(global::System.ComponentModel.DataAnnotations.ValidationAttribute).IsAssignableFrom(type) => + (global::System.ComponentModel.DataAnnotations.ValidationAttribute)global::System.Activator.CreateInstance(type)! + }; + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.CustomValidationAttribute) && args.Length == 2) + { + // CustomValidationAttribute requires special handling + // First argument is a type, second is a method name + if (args[0] is global::System.Type validatingType && args[1] is string methodName) + { + attribute = new global::System.ComponentModel.DataAnnotations.CustomValidationAttribute(validatingType, methodName); + } + else + { + throw new global::System.ArgumentException($"Invalid arguments for CustomValidationAttribute: Type and method name required"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.StringLengthAttribute)) + { + if (args[0] is int maxLength) + { + attribute = new global::System.ComponentModel.DataAnnotations.StringLengthAttribute(maxLength); + } + else + { + throw new global::System.ArgumentException($"Invalid maxLength value for StringLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.MinLengthAttribute)) + { + if (args[0] is int length) + { + attribute = new global::System.ComponentModel.DataAnnotations.MinLengthAttribute(length); + } + else + { + throw new global::System.ArgumentException($"Invalid length value for MinLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.MaxLengthAttribute)) + { + if (args[0] is int length) + { + attribute = new global::System.ComponentModel.DataAnnotations.MaxLengthAttribute(length); + } + else + { + throw new global::System.ArgumentException($"Invalid length value for MaxLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute) && args.Length == 2) + { + if (args[0] is int min && args[1] is int max) + { + attribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(min, max); + } + else if (args[0] is double dmin && args[1] is double dmax) + { + attribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(dmin, dmax); + } + else + { + throw new global::System.ArgumentException($"Invalid range values for RangeAttribute: {args[0]}, {args[1]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute)) + { + if (args[0] is string pattern) + { + attribute = new global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute(pattern); + } + else + { + throw new global::System.ArgumentException($"Invalid pattern for RegularExpressionAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.CompareAttribute)) + { + if (args[0] is string otherProperty) + { + attribute = new global::System.ComponentModel.DataAnnotations.CompareAttribute(otherProperty); + } + else + { + throw new global::System.ArgumentException($"Invalid otherProperty for CompareAttribute: {args[0]}"); + } + } + else if (typeof(global::System.ComponentModel.DataAnnotations.ValidationAttribute).IsAssignableFrom(type)) + { + var constructors = type.GetConstructors(); + var success = false; + attribute = null!; + + foreach (var constructor in constructors) + { + var parameters = constructor.GetParameters(); + if (parameters.Length != args.Length) + continue; + + var convertedArgs = new object[args.Length]; + var canUseConstructor = true; + + for (var i = 0; i < parameters.Length; i++) + { + try + { + if (args[i] != null && args[i].GetType() == parameters[i].ParameterType) + { + // Type already matches, use as-is + convertedArgs[i] = args[i]; + } + else + { + // Try to convert + convertedArgs[i] = global::System.Convert.ChangeType(args[i], parameters[i].ParameterType); + } + } + catch + { + canUseConstructor = false; + break; + } + } + + if (canUseConstructor) + { + attribute = (global::System.ComponentModel.DataAnnotations.ValidationAttribute)global::System.Activator.CreateInstance(type, convertedArgs)!; + success = true; + break; + } + } + + if (!success) + { + throw new global::System.ArgumentException($"Could not find a suitable constructor for validation attribute type: {type.FullName}"); + } + } + else + { + throw new global::System.ArgumentException($"Unsupported validation attribute type: {type.FullName}"); + } + + // Apply named arguments after construction + foreach (var namedArg in k.NamedArguments) + { + var prop = type.GetProperty(namedArg.Key); + if (prop != null && prop.CanWrite) + { + try + { + if (namedArg.Value != null && namedArg.Value.GetType() == prop.PropertyType) + { + // Type already matches, use as-is + prop.SetValue(attribute, namedArg.Value); + } + else + { + // Try to convert + prop.SetValue(attribute, global::System.Convert.ChangeType(namedArg.Value, prop.PropertyType)); + } + } + catch (global::System.Exception ex) + { + throw new global::System.ArgumentException($"Failed to set property {namedArg.Key} on {type.FullName}: {ex.Message}"); + } + } + } + + return attribute; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateRecursiveTypes#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateRecursiveTypes#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..a2b601211c0e --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateRecursiveTypes#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,317 @@ +//HintName: ValidatableInfoResolver.g.cs +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + private readonly global::System.ComponentModel.DataAnnotations.ValidationAttribute[] _validationAttributes; + + public GeneratedValidatablePropertyInfo( + global::System.Type containingType, + global::System.Type propertyType, + string name, + string displayName, + global::System.ComponentModel.DataAnnotations.ValidationAttribute[] validationAttributes) : base(containingType, propertyType, name, displayName) + { + _validationAttributes = validationAttributes; + } + + protected override global::System.ComponentModel.DataAnnotations.ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + global::System.Type type, + ValidatablePropertyInfo[] members) : base(type, members) { } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public bool TryGetValidatableTypeInfo(global::System.Type type, out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo validatableInfo) + { + validatableInfo = null; + if (type == typeof(global::RecursiveType)) + { + validatableInfo = CreateRecursiveType(); + return true; + } + + return false; + } + + // No-ops, rely on runtime code for ParameterInfo-based resolution + public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo, out global::Microsoft.AspNetCore.Http.Validation.IValidatableInfo validatableInfo) + { + validatableInfo = null; + return false; + } + + private ValidatableTypeInfo CreateRecursiveType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::RecursiveType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::RecursiveType), + propertyType: typeof(int), + name: "Value", + displayName: "Value", + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute), [10, 100], []) ?? throw new global::System.InvalidOperationException(@"Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RangeAttribute")] + ), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::RecursiveType), + propertyType: typeof(global::RecursiveType), + name: "Next", + displayName: "Next", + validationAttributes: [] + ), + ] + ); + } + + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "wNU90FRNrQG/m6cp7QRaZQYBAABQcm9ncmFtLmNz")] + public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(global::System.Type AttributeType, object[] Arguments, global::System.Collections.Generic.Dictionary NamedArguments); + private static readonly global::System.Collections.Concurrent.ConcurrentDictionary _cache = new(); + + public static global::System.ComponentModel.DataAnnotations.ValidationAttribute? GetOrCreateValidationAttribute( + global::System.Type attributeType, + object[] arguments, + global::System.Collections.Generic.Dictionary namedArguments) + { + var key = new CacheKey(attributeType, arguments, namedArguments); + return _cache.GetOrAdd(key, static k => + { + var type = k.AttributeType; + var args = k.Arguments; + + global::System.ComponentModel.DataAnnotations.ValidationAttribute attribute; + + if (args.Length == 0) + { + attribute = type switch + { + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute) => new global::System.ComponentModel.DataAnnotations.RequiredAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.EmailAddressAttribute) => new global::System.ComponentModel.DataAnnotations.EmailAddressAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.PhoneAttribute) => new global::System.ComponentModel.DataAnnotations.PhoneAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.UrlAttribute) => new global::System.ComponentModel.DataAnnotations.UrlAttribute(), + global::System.Type t when t == typeof(global::System.ComponentModel.DataAnnotations.CreditCardAttribute) => new global::System.ComponentModel.DataAnnotations.CreditCardAttribute(), + _ when typeof(global::System.ComponentModel.DataAnnotations.ValidationAttribute).IsAssignableFrom(type) => + (global::System.ComponentModel.DataAnnotations.ValidationAttribute)global::System.Activator.CreateInstance(type)! + }; + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.CustomValidationAttribute) && args.Length == 2) + { + // CustomValidationAttribute requires special handling + // First argument is a type, second is a method name + if (args[0] is global::System.Type validatingType && args[1] is string methodName) + { + attribute = new global::System.ComponentModel.DataAnnotations.CustomValidationAttribute(validatingType, methodName); + } + else + { + throw new global::System.ArgumentException($"Invalid arguments for CustomValidationAttribute: Type and method name required"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.StringLengthAttribute)) + { + if (args[0] is int maxLength) + { + attribute = new global::System.ComponentModel.DataAnnotations.StringLengthAttribute(maxLength); + } + else + { + throw new global::System.ArgumentException($"Invalid maxLength value for StringLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.MinLengthAttribute)) + { + if (args[0] is int length) + { + attribute = new global::System.ComponentModel.DataAnnotations.MinLengthAttribute(length); + } + else + { + throw new global::System.ArgumentException($"Invalid length value for MinLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.MaxLengthAttribute)) + { + if (args[0] is int length) + { + attribute = new global::System.ComponentModel.DataAnnotations.MaxLengthAttribute(length); + } + else + { + throw new global::System.ArgumentException($"Invalid length value for MaxLengthAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute) && args.Length == 2) + { + if (args[0] is int min && args[1] is int max) + { + attribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(min, max); + } + else if (args[0] is double dmin && args[1] is double dmax) + { + attribute = new global::System.ComponentModel.DataAnnotations.RangeAttribute(dmin, dmax); + } + else + { + throw new global::System.ArgumentException($"Invalid range values for RangeAttribute: {args[0]}, {args[1]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute)) + { + if (args[0] is string pattern) + { + attribute = new global::System.ComponentModel.DataAnnotations.RegularExpressionAttribute(pattern); + } + else + { + throw new global::System.ArgumentException($"Invalid pattern for RegularExpressionAttribute: {args[0]}"); + } + } + else if (type == typeof(global::System.ComponentModel.DataAnnotations.CompareAttribute)) + { + if (args[0] is string otherProperty) + { + attribute = new global::System.ComponentModel.DataAnnotations.CompareAttribute(otherProperty); + } + else + { + throw new global::System.ArgumentException($"Invalid otherProperty for CompareAttribute: {args[0]}"); + } + } + else if (typeof(global::System.ComponentModel.DataAnnotations.ValidationAttribute).IsAssignableFrom(type)) + { + var constructors = type.GetConstructors(); + var success = false; + attribute = null!; + + foreach (var constructor in constructors) + { + var parameters = constructor.GetParameters(); + if (parameters.Length != args.Length) + continue; + + var convertedArgs = new object[args.Length]; + var canUseConstructor = true; + + for (var i = 0; i < parameters.Length; i++) + { + try + { + if (args[i] != null && args[i].GetType() == parameters[i].ParameterType) + { + // Type already matches, use as-is + convertedArgs[i] = args[i]; + } + else + { + // Try to convert + convertedArgs[i] = global::System.Convert.ChangeType(args[i], parameters[i].ParameterType); + } + } + catch + { + canUseConstructor = false; + break; + } + } + + if (canUseConstructor) + { + attribute = (global::System.ComponentModel.DataAnnotations.ValidationAttribute)global::System.Activator.CreateInstance(type, convertedArgs)!; + success = true; + break; + } + } + + if (!success) + { + throw new global::System.ArgumentException($"Could not find a suitable constructor for validation attribute type: {type.FullName}"); + } + } + else + { + throw new global::System.ArgumentException($"Unsupported validation attribute type: {type.FullName}"); + } + + // Apply named arguments after construction + foreach (var namedArg in k.NamedArguments) + { + var prop = type.GetProperty(namedArg.Key); + if (prop != null && prop.CanWrite) + { + try + { + if (namedArg.Value != null && namedArg.Value.GetType() == prop.PropertyType) + { + // Type already matches, use as-is + prop.SetValue(attribute, namedArg.Value); + } + else + { + // Try to convert + prop.SetValue(attribute, global::System.Convert.ChangeType(namedArg.Value, prop.PropertyType)); + } + } + catch (global::System.Exception ex) + { + throw new global::System.ArgumentException($"Failed to set property {namedArg.Key} on {type.FullName}: {ex.Message}"); + } + } + } + + return attribute; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateTypesWithAttribute#ValidatableInfoResolver.g.verified.cs b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateTypesWithAttribute#ValidatableInfoResolver.g.verified.cs new file mode 100644 index 000000000000..61971e891c2b --- /dev/null +++ b/src/Http/Http.Extensions/test/ValidationsGenerator/snapshots/ValidationsGeneratorTests.CanValidateTypesWithAttribute#ValidatableInfoResolver.g.verified.cs @@ -0,0 +1,407 @@ +//HintName: ValidatableInfoResolver.g.cs +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ +#nullable enable + +namespace System.Runtime.CompilerServices +{ + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + [AttributeUsage(AttributeTargets.Method, AllowMultiple = true)] + file sealed class InterceptsLocationAttribute : System.Attribute + { + public InterceptsLocationAttribute(int version, string data) + { + } + } +} + +namespace Microsoft.AspNetCore.Http.Validation.Generated +{ + using System; + using System.Collections.Generic; + using System.Collections.Concurrent; + using System.ComponentModel.DataAnnotations; + using System.Linq; + using Microsoft.Extensions.DependencyInjection; + using Microsoft.AspNetCore.Http.Validation; + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatablePropertyInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatablePropertyInfo + { + private readonly ValidationAttribute[] _validationAttributes; + + public GeneratedValidatablePropertyInfo( + Type containingType, + Type propertyType, + string name, + string displayName, + bool isEnumerable, + bool isNullable, + bool isRequired, + bool hasValidatableType, + ValidationAttribute[] validationAttributes) : base(containingType, propertyType, name, displayName, isEnumerable, isNullable, isRequired, hasValidatableType) + { + _validationAttributes = validationAttributes; + } + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableParameterInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableParameterInfo + { + private readonly ValidationAttribute[] _validationAttributes; + + public GeneratedValidatableParameterInfo( + string name, + string displayName, + bool isNullable, + bool isRequired, + bool hasValidatableType, + bool isEnumerable, + ValidationAttribute[] validationAttributes) : base(name, displayName, isNullable, isRequired, hasValidatableType, isEnumerable) + { + _validationAttributes = validationAttributes; + } + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file sealed class GeneratedValidatableTypeInfo : global::Microsoft.AspNetCore.Http.Validation.ValidatableTypeInfo + { + public GeneratedValidatableTypeInfo( + Type type, + ValidatablePropertyInfo[] members, + bool implementsIValidatableObject, + Type[]? validatableSubTypes = null) : base(type, members, implementsIValidatableObject, validatableSubTypes) { } + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file class GeneratedValidatableInfoResolver : global::Microsoft.AspNetCore.Http.Validation.IValidatableInfoResolver + { + public ValidatableTypeInfo? GetValidatableTypeInfo(Type type) + { + if (type == typeof(global::SubType)) + { + return CreateSubType(); + } + if (type == typeof(global::SubTypeWithInheritance)) + { + return CreateSubTypeWithInheritance(); + } + if (type == typeof(global::ComplexType)) + { + return CreateComplexType(); + } + + return null; + } + + public ValidatableParameterInfo? GetValidatableParameterInfo(global::System.Reflection.ParameterInfo parameterInfo) + { + + return null; + } + + private ValidatableTypeInfo CreateSubType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::SubType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubType), + propertyType: typeof(string), + name: "RequiredProperty", + displayName: "RequiredProperty", + isEnumerable: false, + isNullable: false, + isRequired: true, + hasValidatableType: false, + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute), Array.Empty(), new Dictionary()) ?? throw new InvalidOperationException("Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RequiredAttribute")]), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubType), + propertyType: typeof(string), + name: "StringWithLength", + displayName: "StringWithLength", + isEnumerable: false, + isNullable: false, + isRequired: false, + hasValidatableType: false, + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.StringLengthAttribute), new string[] { "10" }, new Dictionary()) ?? throw new InvalidOperationException("Failed to create validation attribute global::System.ComponentModel.DataAnnotations.StringLengthAttribute")]) + ], + implementsIValidatableObject: false); + } + private ValidatableTypeInfo CreateSubTypeWithInheritance() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::SubTypeWithInheritance), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::SubTypeWithInheritance), + propertyType: typeof(string), + name: "EmailString", + displayName: "EmailString", + isEnumerable: false, + isNullable: false, + isRequired: false, + hasValidatableType: false, + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.EmailAddressAttribute), Array.Empty(), new Dictionary()) ?? throw new InvalidOperationException("Failed to create validation attribute global::System.ComponentModel.DataAnnotations.EmailAddressAttribute")]) + ], + implementsIValidatableObject: false, + validatableSubTypes: [ + typeof(SubType) + ]); + } + private ValidatableTypeInfo CreateComplexType() + { + return new GeneratedValidatableTypeInfo( + type: typeof(global::ComplexType), + members: [ + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithRange", + displayName: "IntegerWithRange", + isEnumerable: false, + isNullable: false, + isRequired: false, + hasValidatableType: false, + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute), new string[] { "10", "100" }, new Dictionary()) ?? throw new InvalidOperationException("Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RangeAttribute")]), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithRangeAndDisplayName", + displayName: "Valid identifier", + isEnumerable: false, + isNullable: false, + isRequired: false, + hasValidatableType: false, + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute), new string[] { "10", "100" }, new Dictionary()) ?? throw new InvalidOperationException("Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RangeAttribute")]), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::SubType), + name: "PropertyWithMemberAttributes", + displayName: "PropertyWithMemberAttributes", + isEnumerable: false, + isNullable: false, + isRequired: true, + hasValidatableType: true, + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RequiredAttribute), Array.Empty(), new Dictionary()) ?? throw new InvalidOperationException("Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RequiredAttribute")]), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::SubType), + name: "PropertyWithoutMemberAttributes", + displayName: "PropertyWithoutMemberAttributes", + isEnumerable: false, + isNullable: false, + isRequired: false, + hasValidatableType: true, + validationAttributes: []), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::SubTypeWithInheritance), + name: "PropertyWithInheritance", + displayName: "PropertyWithInheritance", + isEnumerable: false, + isNullable: false, + isRequired: false, + hasValidatableType: true, + validationAttributes: []), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(global::System.Collections.Generic.List), + name: "ListOfSubTypes", + displayName: "ListOfSubTypes", + isEnumerable: true, + isNullable: false, + isRequired: false, + hasValidatableType: true, + validationAttributes: []), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "IntegerWithCustomValidationAttribute", + displayName: "IntegerWithCustomValidationAttribute", + isEnumerable: false, + isNullable: false, + isRequired: false, + hasValidatableType: false, + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::CustomValidationAttribute), Array.Empty(), new Dictionary { { "ErrorMessage", "Value must be an even number" } }) ?? throw new InvalidOperationException("Failed to create validation attribute global::CustomValidationAttribute")]), + new GeneratedValidatablePropertyInfo( + containingType: typeof(global::ComplexType), + propertyType: typeof(int), + name: "PropertyWithMultipleAttributes", + displayName: "PropertyWithMultipleAttributes", + isEnumerable: false, + isNullable: false, + isRequired: false, + hasValidatableType: false, + validationAttributes: [ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::CustomValidationAttribute), Array.Empty(), new Dictionary()) ?? throw new InvalidOperationException("Failed to create validation attribute global::CustomValidationAttribute"), ValidationAttributeCache.GetOrCreateValidationAttribute(typeof(global::System.ComponentModel.DataAnnotations.RangeAttribute), new string[] { "10", "100" }, new Dictionary()) ?? throw new InvalidOperationException("Failed to create validation attribute global::System.ComponentModel.DataAnnotations.RangeAttribute")]) + ], + implementsIValidatableObject: false); + } + + + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class GeneratedServiceCollectionExtensions + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute(1, "1zHOloYrguEmrREVCu+15nYBAABQcm9ncmFtLmNz")] + public static IServiceCollection AddValidation(this IServiceCollection services, Action? configureOptions = null) + { + // Use non-extension method to avoid infinite recursion. + return ValidationServiceCollectionExtensions.AddValidation(services, options => + { + options.Resolvers.Insert(0, new GeneratedValidatableInfoResolver()); + if (configureOptions is not null) + { + configureOptions(options); + } + }); + } + } + + [System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.AspNetCore.Http.ValidationsGenerator, Version=42.42.42.42, Culture=neutral, PublicKeyToken=adb9793829ddae60", "42.42.42.42")] + file static class ValidationAttributeCache + { + private sealed record CacheKey(Type AttributeType, string[] Arguments, IReadOnlyDictionary NamedArguments); + private static readonly ConcurrentDictionary _cache = new(); + + public static ValidationAttribute? GetOrCreateValidationAttribute( + Type attributeType, + string[] arguments, + IReadOnlyDictionary namedArguments) + { + var key = new CacheKey(attributeType, arguments, namedArguments); + return _cache.GetOrAdd(key, static k => + { + var type = k.AttributeType; + var args = k.Arguments; + + ValidationAttribute attribute; + + if (args.Length == 0) + { + attribute = type switch + { + Type t when t == typeof(RequiredAttribute) => new RequiredAttribute(), + Type t when t == typeof(EmailAddressAttribute) => new EmailAddressAttribute(), + Type t when t == typeof(PhoneAttribute) => new PhoneAttribute(), + Type t when t == typeof(UrlAttribute) => new UrlAttribute(), + Type t when t == typeof(CreditCardAttribute) => new CreditCardAttribute(), + _ when typeof(ValidationAttribute).IsAssignableFrom(type) => + (ValidationAttribute)Activator.CreateInstance(type)! + }; + } + else if (type == typeof(StringLengthAttribute)) + { + if (!int.TryParse(args[0], out var maxLength)) + throw new ArgumentException($"Invalid maxLength value for StringLengthAttribute: {args[0]}"); + attribute = new StringLengthAttribute(maxLength); + } + else if (type == typeof(MinLengthAttribute)) + { + if (!int.TryParse(args[0], out var length)) + throw new ArgumentException($"Invalid length value for MinLengthAttribute: {args[0]}"); + attribute = new MinLengthAttribute(length); + } + else if (type == typeof(MaxLengthAttribute)) + { + if (!int.TryParse(args[0], out var length)) + throw new ArgumentException($"Invalid length value for MaxLengthAttribute: {args[0]}"); + attribute = new MaxLengthAttribute(length); + } + else if (type == typeof(RangeAttribute) && args.Length == 2) + { + if (int.TryParse(args[0], out var min) && int.TryParse(args[1], out var max)) + attribute = new RangeAttribute(min, max); + else if (double.TryParse(args[0], out var dmin) && double.TryParse(args[1], out var dmax)) + attribute = new RangeAttribute(dmin, dmax); + else + throw new ArgumentException($"Invalid range values for RangeAttribute: {args[0]}, {args[1]}"); + } + else if (type == typeof(RegularExpressionAttribute)) + { + attribute = new RegularExpressionAttribute(args[0]); + } + else if (type == typeof(CompareAttribute)) + { + attribute = new CompareAttribute(args[0]); + } + else if (typeof(ValidationAttribute).IsAssignableFrom(type)) + { + var constructors = type.GetConstructors(); + var success = false; + attribute = null!; + + foreach (var constructor in constructors) + { + var parameters = constructor.GetParameters(); + if (parameters.Length != args.Length) + continue; + + var convertedArgs = new object[args.Length]; + var canUseConstructor = true; + + for (var i = 0; i < parameters.Length; i++) + { + try + { + convertedArgs[i] = Convert.ChangeType(args[i], parameters[i].ParameterType); + } + catch + { + canUseConstructor = false; + break; + } + } + + if (canUseConstructor) + { + attribute = (ValidationAttribute)Activator.CreateInstance(type, convertedArgs)!; + success = true; + break; + } + } + + if (!success) + { + throw new ArgumentException($"Could not find a suitable constructor for validation attribute type: {type.FullName}"); + } + } + else + { + throw new ArgumentException($"Unsupported validation attribute type: {type.FullName}"); + } + + // Apply named arguments after construction + foreach (var namedArg in k.NamedArguments) + { + var prop = type.GetProperty(namedArg.Key); + if (prop != null && prop.CanWrite) + { + try + { + var convertedValue = Convert.ChangeType(namedArg.Value, prop.PropertyType); + prop.SetValue(attribute, convertedValue); + } + catch (Exception ex) + { + throw new ArgumentException($"Failed to set property {namedArg.Key} on {type.FullName}: {ex.Message}"); + } + } + } + + return attribute; + }); + } + } +} \ No newline at end of file diff --git a/src/Http/Http/perf/Microbenchmarks/ValidatableTypesBenchmark.cs b/src/Http/Http/perf/Microbenchmarks/ValidatableTypesBenchmark.cs new file mode 100644 index 000000000000..68987c652fcd --- /dev/null +++ b/src/Http/Http/perf/Microbenchmarks/ValidatableTypesBenchmark.cs @@ -0,0 +1,365 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Http.Validation; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Http.Microbenchmarks; + +public class ValidatableTypeInfoBenchmark +{ + private IValidatableInfo _simpleTypeInfo = null!; + private IValidatableInfo _complexTypeInfo = null!; + private IValidatableInfo _hierarchicalTypeInfo = null!; + private IValidatableInfo _ivalidatableObjectTypeInfo = null!; + + private ValidatableContext _context = null!; + private SimpleModel _simpleModel = null!; + private ComplexModel _complexModel = null!; + private HierarchicalModel _hierarchicalModel = null!; + private ValidatableObjectModel _validatableObjectModel = null!; + + [GlobalSetup] + public void Setup() + { + var services = new ServiceCollection(); + var mockResolver = new MockValidatableTypeInfoResolver(); + + services.AddValidation(options => + { + // Register our mock resolver + options.Resolvers.Insert(0, mockResolver); + }); + + var serviceProvider = services.BuildServiceProvider(); + var validationOptions = serviceProvider.GetRequiredService>().Value; + + _context = new ValidatableContext + { + ValidationOptions = validationOptions, + ValidationContext = new ValidationContext(new object(), serviceProvider, null), + ValidationErrors = new Dictionary(StringComparer.Ordinal) + }; + + // Create the model instances + _simpleModel = new SimpleModel + { + Id = 1, + Name = "Test Name", + Email = "test@example.com" + }; + + _complexModel = new ComplexModel + { + Id = 1, + Name = "Complex Model", + Properties = new Dictionary + { + ["Prop1"] = "Value1", + ["Prop2"] = "Value2" + }, + Items = ["Item1", "Item2", "Item3"], + CreatedOn = DateTime.UtcNow + }; + + _hierarchicalModel = new HierarchicalModel + { + Id = 1, + Name = "Parent Model", + Child = new ChildModel + { + Id = 2, + Name = "Child Model", + ParentId = 1 + }, + Siblings = + [ + new SimpleModel { Id = 3, Name = "Sibling 1", Email = "sibling1@example.com" }, + new SimpleModel { Id = 4, Name = "Sibling 2", Email = "sibling2@example.com" } + ] + }; + + _validatableObjectModel = new ValidatableObjectModel + { + Id = 1, + Name = "Validatable Model", + CustomField = "Valid Value" + }; + + // Get the type info instances from validation options using the mock resolver + validationOptions.TryGetValidatableTypeInfo(typeof(SimpleModel), out _simpleTypeInfo); + validationOptions.TryGetValidatableTypeInfo(typeof(ComplexModel), out _complexTypeInfo); + validationOptions.TryGetValidatableTypeInfo(typeof(HierarchicalModel), out _hierarchicalTypeInfo); + validationOptions.TryGetValidatableTypeInfo(typeof(ValidatableObjectModel), out _ivalidatableObjectTypeInfo); + + // Ensure we have all type infos (this should not be needed with our mock resolver) + if (_simpleTypeInfo == null || _complexTypeInfo == null || + _hierarchicalTypeInfo == null || _ivalidatableObjectTypeInfo == null) + { + throw new InvalidOperationException("Failed to register one or more type infos with mock resolver"); + } + } + + [Benchmark(Description = "Validate Simple Model")] + [BenchmarkCategory("Simple")] + public async Task ValidateSimpleModel() + { + _context.ValidationErrors.Clear(); + await _simpleTypeInfo.ValidateAsync(_simpleModel, _context, default); + } + + [Benchmark(Description = "Validate Complex Model")] + [BenchmarkCategory("Complex")] + public async Task ValidateComplexModel() + { + _context.ValidationErrors.Clear(); + await _complexTypeInfo.ValidateAsync(_complexModel, _context, default); + } + + [Benchmark(Description = "Validate Hierarchical Model")] + [BenchmarkCategory("Hierarchical")] + public async Task ValidateHierarchicalModel() + { + _context.ValidationErrors.Clear(); + await _hierarchicalTypeInfo.ValidateAsync(_hierarchicalModel, _context, default); + } + + [Benchmark(Description = "Validate IValidatableObject Model")] + [BenchmarkCategory("IValidatableObject")] + public async Task ValidateIValidatableObjectModel() + { + _context.ValidationErrors.Clear(); + await _ivalidatableObjectTypeInfo.ValidateAsync(_validatableObjectModel, _context, default); + } + + [Benchmark(Description = "Validate invalid Simple Model")] + [BenchmarkCategory("Invalid")] + public async Task ValidateInvalidSimpleModel() + { + _context.ValidationErrors.Clear(); + _simpleModel.Email = "invalid-email"; + await _simpleTypeInfo.ValidateAsync(_simpleModel, _context, default); + } + + [Benchmark(Description = "Validate invalid IValidatableObject Model")] + [BenchmarkCategory("Invalid")] + public async Task ValidateInvalidIValidatableObjectModel() + { + _context.ValidationErrors.Clear(); + _validatableObjectModel.CustomField = "Invalid"; + await _ivalidatableObjectTypeInfo.ValidateAsync(_validatableObjectModel, _context, default); + } + + #region Helper methods to create type info instances manually if needed + + private ValidatablePropertyInfo CreatePropertyInfo(string name, Type type, params ValidationAttribute[] attributes) + { + return new MockValidatablePropertyInfo( + typeof(SimpleModel), + type, + name, + name, + attributes); + } + + #endregion + + #region Test Models + + public class SimpleModel + { + public int Id { get; set; } + + [Required] + public string Name { get; set; } + + [EmailAddress] + public string Email { get; set; } + } + + public class ComplexModel + { + public int Id { get; set; } + + [Required] + public string Name { get; set; } + + public Dictionary Properties { get; set; } + + public List Items { get; set; } + + public DateTime CreatedOn { get; set; } + } + + public class ChildModel + { + public int Id { get; set; } + + [Required] + public string Name { get; set; } + + public int ParentId { get; set; } + } + + public class HierarchicalModel + { + public int Id { get; set; } + + [Required] + public string Name { get; set; } + + public ChildModel Child { get; set; } + + public List Siblings { get; set; } + } + + public class ValidatableObjectModel : IValidatableObject + { + public int Id { get; set; } + + [Required] + public string Name { get; set; } + + public string CustomField { get; set; } + + public IEnumerable Validate(ValidationContext validationContext) + { + if (CustomField == "Invalid") + { + yield return new ValidationResult("CustomField has an invalid value", new[] { nameof(CustomField) }); + } + } + } + + #endregion + + #region Mock Implementations for Testing + + private class MockValidatableTypeInfo(Type type, ValidatablePropertyInfo[] members) : ValidatableTypeInfo(type, members) + { + } + + private class MockValidatablePropertyInfo( + Type containingType, + Type propertyType, + string name, + string displayName, + ValidationAttribute[] validationAttributes) : ValidatablePropertyInfo(containingType, propertyType, name, displayName) + { + private readonly ValidationAttribute[] _validationAttributes = validationAttributes; + + protected override ValidationAttribute[] GetValidationAttributes() => _validationAttributes; + } + + #endregion + + #region Mock Resolver Implementation + + private class MockValidatableTypeInfoResolver : IValidatableInfoResolver + { + private readonly Dictionary _typeInfoCache = []; + + public MockValidatableTypeInfoResolver() + { + // Initialize the cache with our test models + _typeInfoCache[typeof(SimpleModel)] = CreateSimpleModelTypeInfo(); + _typeInfoCache[typeof(ComplexModel)] = CreateComplexModelTypeInfo(); + _typeInfoCache[typeof(HierarchicalModel)] = CreateHierarchicalModelTypeInfo(); + _typeInfoCache[typeof(ValidatableObjectModel)] = CreateValidatableObjectModelTypeInfo(); + + // Add child models that might be validated separately + _typeInfoCache[typeof(ChildModel)] = CreateChildModelTypeInfo(); + } + + private ValidatableTypeInfo CreateSimpleModelTypeInfo() + { + return new MockValidatableTypeInfo( + typeof(SimpleModel), + [ + CreatePropertyInfo(typeof(SimpleModel), "Id", typeof(int)), + CreatePropertyInfo(typeof(SimpleModel), "Name", typeof(string)), + CreatePropertyInfo(typeof(SimpleModel), "Email", typeof(string), new EmailAddressAttribute()) + ]); + } + + private ValidatableTypeInfo CreateComplexModelTypeInfo() + { + return new MockValidatableTypeInfo( + typeof(ComplexModel), + [ + CreatePropertyInfo(typeof(ComplexModel), "Id", typeof(int)), + CreatePropertyInfo(typeof(ComplexModel), "Name", typeof(string)), + CreatePropertyInfo(typeof(ComplexModel), "Properties", typeof(Dictionary)), + CreatePropertyInfo(typeof(ComplexModel), "Items", typeof(List)), + CreatePropertyInfo(typeof(ComplexModel), "CreatedOn", typeof(DateTime)) + ]); + } + + private ValidatableTypeInfo CreateChildModelTypeInfo() + { + return new MockValidatableTypeInfo( + typeof(ChildModel), + [ + CreatePropertyInfo(typeof(ChildModel), "Id", typeof(int)), + CreatePropertyInfo(typeof(ChildModel), "Name", typeof(string)), + CreatePropertyInfo(typeof(ChildModel), "ParentId", typeof(int)) + ]); + } + + private ValidatableTypeInfo CreateHierarchicalModelTypeInfo() + { + return new MockValidatableTypeInfo( + typeof(HierarchicalModel), + [ + CreatePropertyInfo(typeof(HierarchicalModel), "Id", typeof(int)), + CreatePropertyInfo(typeof(HierarchicalModel), "Name", typeof(string)), + CreatePropertyInfo(typeof(HierarchicalModel), "Child", typeof(ChildModel)), + CreatePropertyInfo(typeof(HierarchicalModel), "Siblings", typeof(List)) + ]); + } + + private ValidatableTypeInfo CreateValidatableObjectModelTypeInfo() + { + return new MockValidatableTypeInfo( + typeof(ValidatableObjectModel), + [ + CreatePropertyInfo(typeof(ValidatableObjectModel), "Id", typeof(int)), + CreatePropertyInfo(typeof(ValidatableObjectModel), "Name", typeof(string)), + CreatePropertyInfo(typeof(ValidatableObjectModel), "CustomField", typeof(string)) + ]); + } + + private ValidatablePropertyInfo CreatePropertyInfo(Type containingType, string name, Type type, params ValidationAttribute[] attributes) + { + return new MockValidatablePropertyInfo( + containingType, + type, + name, + name, // Use name as display name + attributes); + } + + public bool TryGetValidatableTypeInfo(Type type, out IValidatableInfo validatableInfo) + { + if (_typeInfoCache.TryGetValue(type, out var typeInfo)) + { + validatableInfo = typeInfo; + return true; + } + validatableInfo = null; + return false; + } + + public bool TryGetValidatableParameterInfo(ParameterInfo parameterInfo, out IValidatableInfo validatableInfo) + { + validatableInfo = null; + return false; + } + } + #endregion +} diff --git a/src/Http/Routing/src/Builder/ValidationRouteHandlerBuilderExtensions.cs b/src/Http/Routing/src/Builder/ValidationRouteHandlerBuilderExtensions.cs new file mode 100644 index 000000000000..6ab1792eec99 --- /dev/null +++ b/src/Http/Routing/src/Builder/ValidationRouteHandlerBuilderExtensions.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Http.Metadata; + +namespace Microsoft.AspNetCore.Builder; + +/// +/// Extension methods for to interact with +/// parameter validation features. +/// +public static class ValidationEndpointConventionBuilderExtensions +{ + /// + /// Disables validation for the specified endpoint. + /// + /// The type of the builder. + /// The endpoint convention builder. + /// + /// The for chaining. + /// + public static TBuilder DisableValidation(this TBuilder builder) + where TBuilder : IEndpointConventionBuilder + { + builder.WithMetadata(new DisableValidationMetadata()); + return builder; + } + + private sealed class DisableValidationMetadata : IDisableValidationMetadata + { + } +} diff --git a/src/Http/Routing/src/PublicAPI.Unshipped.txt b/src/Http/Routing/src/PublicAPI.Unshipped.txt index 7dc5c58110bf..0612dc9ff2b0 100644 --- a/src/Http/Routing/src/PublicAPI.Unshipped.txt +++ b/src/Http/Routing/src/PublicAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.AspNetCore.Builder.ValidationEndpointConventionBuilderExtensions +static Microsoft.AspNetCore.Builder.ValidationEndpointConventionBuilderExtensions.DisableValidation(this TBuilder builder) -> TBuilder diff --git a/src/Http/Routing/src/RouteEndpointDataSource.cs b/src/Http/Routing/src/RouteEndpointDataSource.cs index 2ed6ff242276..a4e993c4aaa3 100644 --- a/src/Http/Routing/src/RouteEndpointDataSource.cs +++ b/src/Http/Routing/src/RouteEndpointDataSource.cs @@ -2,12 +2,16 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Linq; using System.Reflection; using System.Runtime.CompilerServices; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Metadata; +using Microsoft.AspNetCore.Http.Validation; using Microsoft.AspNetCore.Routing.Patterns; using Microsoft.Extensions.FileProviders; +using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Routing; @@ -59,6 +63,8 @@ public RouteHandlerBuilder AddRouteHandler( Func createHandlerRequestDelegateFunc, MethodInfo methodInfo) { + // Initialize all route handlers with validation convention if validation options + // are registered. var conventions = new ThrowOnAddAfterEndpointBuiltConventionCollection(); var finallyConventions = new ThrowOnAddAfterEndpointBuiltConventionCollection(); @@ -100,7 +106,7 @@ public override IReadOnlyList Endpoints public override IReadOnlyList GetGroupedEndpoints(RouteGroupContext context) { var endpoints = new RouteEndpoint[_routeEntries.Count]; - for (int i = 0; i < _routeEntries.Count; i++) + for (var i = 0; i < _routeEntries.Count; i++) { endpoints[i] = (RouteEndpoint)CreateRouteEndpointBuilder(_routeEntries[i], context.Prefix, context.Conventions, context.FinallyConventions).Build(); } @@ -155,7 +161,7 @@ private RouteEndpointBuilder CreateRouteEndpointBuilder( // If we're not a route handler, we started with a fully realized (although unfiltered) RequestDelegate, so we can just redirect to that // while running any conventions. We'll put the original back if it remains unfiltered right before building the endpoint. - RequestDelegate? factoryCreatedRequestDelegate = isRouteHandler ? null : (RequestDelegate)entry.RouteHandler; + var factoryCreatedRequestDelegate = isRouteHandler ? null : (RequestDelegate)entry.RouteHandler; // Let existing conventions capture and call into builder.RequestDelegate as long as they do so after it has been created. RequestDelegate redirectRequestDelegate = context => @@ -232,6 +238,13 @@ private RouteEndpointBuilder CreateRouteEndpointBuilder( entrySpecificConvention(builder); } + var hasValidationOptions = builder.ApplicationServices.GetService(typeof(IOptions)) is not null; + var hasDisableValidationMetadata = builder.Metadata.OfType().FirstOrDefault() is not null; + if (hasValidationOptions && !hasDisableValidationMetadata) + { + builder.FilterFactories.Insert(0, ValidationEndpointFilterFactory.Create); + } + // If no convention has modified builder.RequestDelegate, we can use the RequestDelegate returned by the RequestDelegateFactory directly. var conventionOverriddenRequestDelegate = ReferenceEquals(builder.RequestDelegate, redirectRequestDelegate) ? null : builder.RequestDelegate; diff --git a/src/Http/Routing/src/ValidationEndpointFilterFactory.cs b/src/Http/Routing/src/ValidationEndpointFilterFactory.cs new file mode 100644 index 000000000000..ced4c0e90125 --- /dev/null +++ b/src/Http/Routing/src/ValidationEndpointFilterFactory.cs @@ -0,0 +1,100 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Http.Validation; + +internal static class ValidationEndpointFilterFactory +{ + private const string ValidationContextJustification = "The DisplayName property is always statically initialized in the ValidationContext through this codepath."; + + public static EndpointFilterDelegate Create(EndpointFilterFactoryContext context, EndpointFilterDelegate next) + { + var parameters = context.MethodInfo.GetParameters(); + var options = context.ApplicationServices.GetService>()?.Value; + if (options is null) + { + return next; + } + + var parameterCount = parameters.Length; + var validatableParameters = new IValidatableInfo[parameterCount]; + var parameterDisplayNames = new string[parameterCount]; + var hasValidatableParameters = false; + + for (var i = 0; i < parameterCount; i++) + { + if (options.TryGetValidatableParameterInfo(parameters[i], out var validatableParameter)) + { + validatableParameters[i] = validatableParameter; + parameterDisplayNames[i] = GetDisplayName(parameters[i]); + hasValidatableParameters = true; + } + } + + if (!hasValidatableParameters) + { + return next; + } + + var validatableContext = new ValidatableContext { ValidationOptions = options }; + return async (context) => + { + validatableContext.ValidationErrors?.Clear(); + + for (var i = 0; i < context.Arguments.Count; i++) + { + var validatableParameter = validatableParameters[i]; + var displayName = parameterDisplayNames[i]; + + var argument = context.Arguments[i]; + if (argument is null || validatableParameter is null) + { + continue; + } + // ValidationContext is not trim-friendly in codepaths that don't + // initialize an explicit DisplayName. We can suppress the warning here. + // Eventually, this can be removed when the code is updated to + // use https://github.com/dotnet/runtime/issues/113134. + var validationContext = CreateValidationContext(argument, displayName, context.HttpContext.RequestServices); + validatableContext.ValidationContext = validationContext; + await validatableParameter.ValidateAsync(argument, validatableContext, context.HttpContext.RequestAborted); + } + + if (validatableContext.ValidationErrors is { Count: > 0 }) + { + context.HttpContext.Response.StatusCode = StatusCodes.Status400BadRequest; + context.HttpContext.Response.ContentType = "application/problem+json"; + return await ValueTask.FromResult(new HttpValidationProblemDetails(validatableContext.ValidationErrors)); + } + + return await next(context); + }; + } + + /// + /// ValidationContext is not trim-friendly in codepaths that don't + /// initialize an explicit DisplayName. We can suppress the warning here. + /// Eventually, this can be removed when the code is updated to + /// use https://github.com/dotnet/runtime/issues/113134. + /// + [UnconditionalSuppressMessage("Trimming", "IL2026:Members annotated with 'RequiresUnreferencedCodeAttribute' require dynamic access otherwise can break functionality when trimming application code", Justification = ValidationContextJustification)] + private static ValidationContext CreateValidationContext(object argument, string displayName, IServiceProvider serviceProvider) + => new(argument, serviceProvider, items: null) { DisplayName = displayName }; + + private static string GetDisplayName(ParameterInfo parameterInfo) + { + var displayAttribute = parameterInfo.GetCustomAttribute(); + if (displayAttribute != null) + { + return displayAttribute.Name ?? parameterInfo.Name!; + } + + return parameterInfo.Name!; + } +} diff --git a/src/Http/samples/MinimalValidationSample/MinimalValidationSample.csproj b/src/Http/samples/MinimalValidationSample/MinimalValidationSample.csproj new file mode 100644 index 000000000000..51a1eb16576f --- /dev/null +++ b/src/Http/samples/MinimalValidationSample/MinimalValidationSample.csproj @@ -0,0 +1,24 @@ + + + + $(DefaultNetCoreTargetFramework) + enable + true + + + + + + + + + + + + + + + + + + diff --git a/src/Http/samples/MinimalValidationSample/Program.cs b/src/Http/samples/MinimalValidationSample/Program.cs new file mode 100644 index 000000000000..e6bc26af771e --- /dev/null +++ b/src/Http/samples/MinimalValidationSample/Program.cs @@ -0,0 +1,108 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.ComponentModel.DataAnnotations; +using Microsoft.AspNetCore.Http.Validation; + +var builder = WebApplication.CreateBuilder(args); + +builder.Services.AddValidation(); + +var app = builder.Build(); + +// ValidationEndpointFilterFactory is implicitly enabled on all endpoints +app.MapGet("/customers/{id}", ([Range(1, int.MaxValue)] int id) => + $"Getting customer with ID: {id}"); + +app.MapPost("/customers", (Customer customer) => +{ + // Validation happens automatically before this code runs + return TypedResults.Created($"/customers/{customer.Name}", customer); +}); + +app.MapPost("/orders", (Order order) => +{ + // Both attribute validation and IValidatableObject.Validate are called automatically + return TypedResults.Created($"/orders/{order.OrderId}", order); +}); + +app.MapPost("/products", ([EvenNumberAttribute(ErrorMessage = "Product ID must be even")] int productId, + [Required] string name) => +{ + return TypedResults.Ok(new { productId, name }); +}) +.DisableValidation(); + +app.Run(); + +// Define validatable types with the ValidatableType attribute +[ValidatableType] +public class Customer +{ + [Required] + public required string Name { get; set; } + + [EmailAddress] + public required string Email { get; set; } + + [Range(18, 120)] + [Display(Name = "Customer Age")] + public int Age { get; set; } + + // Complex property with nested validation + public Address HomeAddress { get; set; } = new Address + { + Street = "123 Main St", + City = "Anytown", + ZipCode = "12345" + }; +} + +public class Address +{ + [Required] + public required string Street { get; set; } + + [Required] + public required string City { get; set; } + + [StringLength(5)] + public required string ZipCode { get; set; } +} + +// Define a type implementing IValidatableObject for custom validation +[ValidatableType] +public class Order : IValidatableObject +{ + [Range(1, int.MaxValue)] + public int OrderId { get; set; } + + [Required] + public required string ProductName { get; set; } + + public int Quantity { get; set; } + + // Custom validation logic using IValidatableObject + public IEnumerable Validate(ValidationContext validationContext) + { + if (Quantity <= 0) + { + yield return new ValidationResult( + "Quantity must be greater than zero", + [nameof(Quantity)]); + } + } +} + +// Use a custom validation attribute +public class EvenNumberAttribute : ValidationAttribute +{ + public override bool IsValid(object? value) + { + if (value is int number) + { + return number % 2 == 0; + } + return false; + } +} diff --git a/src/Http/samples/MinimalValidationSample/Properties/launchSettings.json b/src/Http/samples/MinimalValidationSample/Properties/launchSettings.json new file mode 100644 index 000000000000..6e42095c6bd3 --- /dev/null +++ b/src/Http/samples/MinimalValidationSample/Properties/launchSettings.json @@ -0,0 +1,13 @@ +{ + "profiles": { + "HttpApiSampleApp": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": true, + "applicationUrl": "https://localhost:5022;http://localhost:5021", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} diff --git a/src/Http/samples/MinimalValidationSample/appsettings.Development.json b/src/Http/samples/MinimalValidationSample/appsettings.Development.json new file mode 100644 index 000000000000..8983e0fc1c5e --- /dev/null +++ b/src/Http/samples/MinimalValidationSample/appsettings.Development.json @@ -0,0 +1,9 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft": "Warning", + "Microsoft.Hosting.Lifetime": "Information" + } + } +} diff --git a/src/Http/samples/MinimalValidationSample/appsettings.json b/src/Http/samples/MinimalValidationSample/appsettings.json new file mode 100644 index 000000000000..d9d9a9bff6fd --- /dev/null +++ b/src/Http/samples/MinimalValidationSample/appsettings.json @@ -0,0 +1,10 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Information", + "Microsoft": "Warning", + "Microsoft.Hosting.Lifetime": "Information" + } + }, + "AllowedHosts": "*" +}