Skip to content

Commit c22a853

Browse files
authored
Fix handling of record types in validations source generator (#61402)
* Fix handling of record types in validations source generator * Address feedback and add tests
1 parent 183c128 commit c22a853

13 files changed

+984
-39
lines changed

Diff for: src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Emitters/ValidationsGenerator.Emitter.cs

+30-4
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ public bool TryGetValidatableParameterInfo(global::System.Reflection.ParameterIn
106106
file static class GeneratedServiceCollectionExtensions
107107
{
108108
{{addValidation.GetInterceptsLocationAttributeSyntax()}}
109-
public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action<ValidationOptions>? configureOptions = null)
109+
public static global::Microsoft.Extensions.DependencyInjection.IServiceCollection AddValidation(this global::Microsoft.Extensions.DependencyInjection.IServiceCollection services, global::System.Action<global::Microsoft.AspNetCore.Http.Validation.ValidationOptions>? configureOptions = null)
110110
{
111111
// Use non-extension method to avoid infinite recursion.
112112
return global::Microsoft.Extensions.DependencyInjection.ValidationServiceCollectionExtensions.AddValidation(services, options =>
@@ -133,13 +133,39 @@ private sealed record CacheKey(global::System.Type ContainingType, string Proper
133133
var key = new CacheKey(containingType, propertyName);
134134
return _cache.GetOrAdd(key, static k =>
135135
{
136+
var results = new global::System.Collections.Generic.List<global::System.ComponentModel.DataAnnotations.ValidationAttribute>();
137+
138+
// Get attributes from the property
136139
var property = k.ContainingType.GetProperty(k.PropertyName);
137-
if (property == null)
140+
if (property != null)
141+
{
142+
var propertyAttributes = global::System.Reflection.CustomAttributeExtensions
143+
.GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(property, inherit: true);
144+
145+
results.AddRange(propertyAttributes);
146+
}
147+
148+
// Check constructors for parameters that match the property name
149+
// to handle record scenarios
150+
foreach (var constructor in k.ContainingType.GetConstructors())
138151
{
139-
return [];
152+
// Look for parameter with matching name (case insensitive)
153+
var parameter = global::System.Linq.Enumerable.FirstOrDefault(
154+
constructor.GetParameters(),
155+
p => string.Equals(p.Name, k.PropertyName, global::System.StringComparison.OrdinalIgnoreCase));
156+
157+
if (parameter != null)
158+
{
159+
var paramAttributes = global::System.Reflection.CustomAttributeExtensions
160+
.GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(parameter, inherit: true);
161+
162+
results.AddRange(paramAttributes);
163+
164+
break;
165+
}
140166
}
141167
142-
return [.. global::System.Reflection.CustomAttributeExtensions.GetCustomAttributes<global::System.ComponentModel.DataAnnotations.ValidationAttribute>(property, inherit: true)];
168+
return results.ToArray();
143169
});
144170
}
145171
}

Diff for: src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Extensions/ITypeSymbolExtensions.cs

+21
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33

44
using System.Collections.Immutable;
5+
using System.Linq;
56
using Microsoft.CodeAnalysis;
67

78
namespace Microsoft.AspNetCore.Http.ValidationsGenerator;
@@ -101,4 +102,24 @@ internal static bool IsExemptType(this ITypeSymbol type, RequiredSymbols require
101102
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.Stream)
102103
|| SymbolEqualityComparer.Default.Equals(type, requiredSymbols.PipeReader);
103104
}
105+
106+
internal static IPropertySymbol? FindPropertyIncludingBaseTypes(this INamedTypeSymbol typeSymbol, string propertyName)
107+
{
108+
var property = typeSymbol.GetMembers()
109+
.OfType<IPropertySymbol>()
110+
.FirstOrDefault(p => string.Equals(p.Name, propertyName, System.StringComparison.OrdinalIgnoreCase));
111+
112+
if (property != null)
113+
{
114+
return property;
115+
}
116+
117+
// If not found, recursively search base types
118+
if (typeSymbol.BaseType is INamedTypeSymbol baseType)
119+
{
120+
return FindPropertyIncludingBaseTypes(baseType, propertyName);
121+
}
122+
123+
return null;
124+
}
104125
}

Diff for: src/Http/Http.Extensions/gen/Microsoft.AspNetCore.Http.ValidationsGenerator/Parsers/ValidationsGenerator.TypesParser.cs

+59
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,74 @@ internal bool TryExtractValidatableType(ITypeSymbol typeSymbol, RequiredSymbols
8989
internal ImmutableArray<ValidatableProperty> ExtractValidatableMembers(ITypeSymbol typeSymbol, RequiredSymbols requiredSymbols, ref HashSet<ValidatableType> validatableTypes, ref List<ITypeSymbol> visitedTypes)
9090
{
9191
var members = new List<ValidatableProperty>();
92+
var resolvedRecordProperty = new List<IPropertySymbol>();
93+
94+
// Special handling for record types to extract properties from
95+
// the primary constructor.
96+
if (typeSymbol is INamedTypeSymbol { IsRecord: true } namedType)
97+
{
98+
// Find the primary constructor for the record, account
99+
// for members that are in base types to account for
100+
// record inheritance scenarios
101+
var primaryConstructor = namedType.Constructors
102+
.FirstOrDefault(c => c.Parameters.Length > 0 && c.Parameters.All(p =>
103+
namedType.FindPropertyIncludingBaseTypes(p.Name) != null));
104+
105+
if (primaryConstructor != null)
106+
{
107+
// Process all parameters in constructor order to maintain parameter ordering
108+
foreach (var parameter in primaryConstructor.Parameters)
109+
{
110+
// Find the corresponding property in this type, we ignore
111+
// base types here since that will be handled by the inheritance
112+
// checks in the default ValidatableTypeInfo implementation.
113+
var correspondingProperty = typeSymbol.GetMembers()
114+
.OfType<IPropertySymbol>()
115+
.FirstOrDefault(p => string.Equals(p.Name, parameter.Name, System.StringComparison.OrdinalIgnoreCase));
116+
117+
if (correspondingProperty != null)
118+
{
119+
resolvedRecordProperty.Add(correspondingProperty);
120+
121+
// Check if the property's type is validatable, this resolves
122+
// validatable types in the inheritance hierarchy
123+
var hasValidatableType = TryExtractValidatableType(
124+
correspondingProperty.Type.UnwrapType(requiredSymbols.IEnumerable),
125+
requiredSymbols,
126+
ref validatableTypes,
127+
ref visitedTypes);
128+
129+
members.Add(new ValidatableProperty(
130+
ContainingType: correspondingProperty.ContainingType,
131+
Type: correspondingProperty.Type,
132+
Name: correspondingProperty.Name,
133+
DisplayName: parameter.GetDisplayName(requiredSymbols.DisplayAttribute) ??
134+
correspondingProperty.GetDisplayName(requiredSymbols.DisplayAttribute),
135+
Attributes: []));
136+
}
137+
}
138+
}
139+
}
140+
141+
// Handle properties for classes and any properties not handled by the constructor
92142
foreach (var member in typeSymbol.GetMembers().OfType<IPropertySymbol>())
93143
{
144+
// Skip compiler generated properties and properties already processed via
145+
// the record processing logic above.
146+
if (member.IsImplicitlyDeclared || resolvedRecordProperty.Contains(member, SymbolEqualityComparer.Default))
147+
{
148+
continue;
149+
}
150+
94151
var hasValidatableType = TryExtractValidatableType(member.Type.UnwrapType(requiredSymbols.IEnumerable), requiredSymbols, ref validatableTypes, ref visitedTypes);
95152
var attributes = ExtractValidationAttributes(member, requiredSymbols, out var isRequired);
153+
96154
// If the member has no validation attributes or validatable types and is not required, skip it.
97155
if (attributes.IsDefaultOrEmpty && !hasValidatableType && !isRequired)
98156
{
99157
continue;
100158
}
159+
101160
members.Add(new ValidatableProperty(
102161
ContainingType: member.ContainingType,
103162
Type: member.Type,

Diff for: src/Http/Http.Extensions/test/ValidationsGenerator/ValidationsGenerator.ComplexType.cs

+7-7
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ public class ComplexType
3939
public int IntegerWithRangeAndDisplayName { get; set; } = 50;
4040
4141
[Required]
42-
public SubType PropertyWithMemberAttributes { get; set; } = new SubType();
42+
public SubType PropertyWithMemberAttributes { get; set; } = new SubType("some-value", default);
4343
44-
public SubType PropertyWithoutMemberAttributes { get; set; } = new SubType();
44+
public SubType PropertyWithoutMemberAttributes { get; set; } = new SubType("some-value", default);
4545
46-
public SubTypeWithInheritance PropertyWithInheritance { get; set; } = new SubTypeWithInheritance();
46+
public SubTypeWithInheritance PropertyWithInheritance { get; set; } = new SubTypeWithInheritance("some-value", default);
4747
4848
public List<SubType> ListOfSubTypes { get; set; } = [];
4949
@@ -62,16 +62,16 @@ public class DerivedValidationAttribute : ValidationAttribute
6262
public override bool IsValid(object? value) => value is int number && number % 2 == 0;
6363
}
6464
65-
public class SubType
65+
public class SubType(string? requiredProperty, string? stringWithLength)
6666
{
6767
[Required]
68-
public string RequiredProperty { get; set; } = "some-value";
68+
public string RequiredProperty { get; } = requiredProperty;
6969
7070
[StringLength(10)]
71-
public string? StringWithLength { get; set; }
71+
public string? StringWithLength { get; } = stringWithLength;
7272
}
7373
74-
public class SubTypeWithInheritance : SubType
74+
public class SubTypeWithInheritance(string? requiredProperty, string? stringWithLength) : SubType(requiredProperty, stringWithLength)
7575
{
7676
[EmailAddress]
7777
public string? EmailString { get; set; }

0 commit comments

Comments
 (0)