From ed0017ba687aa4f32701126121665eda8c9b63d7 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 11:51:35 -0500 Subject: [PATCH 1/6] Refactor. --- .../AddEndpointHandlersGenerator.cs | 50 ++--------- .../Common/EndpointConfiguration.cs | 4 +- .../Common/EndpointConfigurationFactory.cs | 11 ++- .../Common/EndpointGroup.cs | 8 ++ .../Common/RequestHandlerClass.cs | 47 +++++++++-- ...eEntry.cs => RequestHandlerClassHelper.cs} | 43 +++++----- .../Common/RequestHandlerMethod.cs | 45 ++++++++-- .../Common/RequestHandlerMethodHelper.cs | 25 ++++++ .../Common/StringExtensions.cs | 9 ++ src/GeneratedEndpoints/MinimalApiGenerator.cs | 75 +++++------------ .../UseEndpointHandlersGenerator.cs | 84 ++++++++++--------- 11 files changed, 219 insertions(+), 182 deletions(-) create mode 100644 src/GeneratedEndpoints/Common/EndpointGroup.cs rename src/GeneratedEndpoints/Common/{RequestHandlerClassCacheEntry.cs => RequestHandlerClassHelper.cs} (81%) create mode 100644 src/GeneratedEndpoints/Common/RequestHandlerMethodHelper.cs diff --git a/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs index 2da92c7..bb258ee 100644 --- a/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs +++ b/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs @@ -1,4 +1,5 @@ -using System.Text; +using System.Collections.Immutable; +using System.Text; using GeneratedEndpoints.Common; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.Text; @@ -12,12 +13,12 @@ namespace GeneratedEndpoints; internal static class AddEndpointHandlersGenerator { - public static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) + public static void GenerateSource(SourceProductionContext context, ImmutableSortedDictionary> grouped) { context.CancellationToken.ThrowIfCancellationRequested(); - var nonStaticClassNames = GetDistinctNonStaticClassNames(requestHandlers); - var source = GetAddEndpointHandlersStringBuilder(nonStaticClassNames); + var nonStaticClassNames = grouped.Keys.Where(x => !x.IsStatic).Select(x => x.Name).ToList(); + var source = new StringBuilder(); source.AppendLine(FileHeader); source.AppendLine(); @@ -61,45 +62,4 @@ public static void GenerateSource(SourceProductionContext context, EquatableImmu var sourceText = StringBuilderPool.ToStringAndReturn(source); context.AddSource(AddEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); } - - private static List GetDistinctNonStaticClassNames(EquatableImmutableArray requestHandlers) - { - var classNames = new List(); - if (requestHandlers.Count == 0) - return classNames; - - var seen = new HashSet(StringComparer.Ordinal); - for (var index = 0; index < requestHandlers.Count; index++) - { - var requestHandler = requestHandlers[index]; - if (requestHandler.Class.IsStatic) - continue; - - var className = requestHandler.Class.Name; - if (seen.Add(className)) - classNames.Add(className); - } - - return classNames; - } - - private static StringBuilder GetAddEndpointHandlersStringBuilder(List nonStaticClassNames) - { - var estimate = 512L; - for (var index = 0; index < nonStaticClassNames.Count; index++) - { - var className = nonStaticClassNames[index]; - estimate += 36 + className.Length; - } - - estimate += Math.Max(256, nonStaticClassNames.Count * 12); - estimate = (long)(estimate * 1.10); - - if (estimate < 512) - estimate = 512; - else if (estimate > int.MaxValue) - estimate = int.MaxValue; - - return StringBuilderPool.Get((int)estimate); - } } diff --git a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs index 60c5bbc..3cf41fb 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs @@ -27,7 +27,5 @@ internal readonly record struct EndpointConfiguration public required bool WithRequestTimeout { get; init; } public required string? RequestTimeoutPolicyName { get; init; } public required int? Order { get; init; } - public required string? GroupIdentifier { get; init; } - public required string? GroupPattern { get; init; } - public required string? GroupName { get; init; } + public required EndpointGroup? Group { get; init; } } diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index f986ac9..9bbcd49 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -172,9 +172,14 @@ public static EndpointConfiguration Create(ISymbol symbol) WithRequestTimeout = withRequestTimeout ?? false, RequestTimeoutPolicyName = requestTimeoutPolicyName, Order = order, - GroupIdentifier = groupIdentifier, - GroupPattern = groupPattern, - GroupName = groupName, + Group = groupIdentifier is not null && groupPattern is not null + ? new EndpointGroup + { + Identifier = groupIdentifier, + Pattern = groupPattern, + Name = groupName, + } + : null, }; } diff --git a/src/GeneratedEndpoints/Common/EndpointGroup.cs b/src/GeneratedEndpoints/Common/EndpointGroup.cs new file mode 100644 index 0000000..259a904 --- /dev/null +++ b/src/GeneratedEndpoints/Common/EndpointGroup.cs @@ -0,0 +1,8 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct EndpointGroup +{ + public required string Identifier { get; init; } + public required string Pattern { get; init; } + public required string? Name { get; init; } +} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClass.cs b/src/GeneratedEndpoints/Common/RequestHandlerClass.cs index c5e868b..e3c8336 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClass.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClass.cs @@ -1,9 +1,42 @@ namespace GeneratedEndpoints.Common; -internal readonly record struct RequestHandlerClass( - string Name, - bool IsStatic, - bool HasConfigureMethod, - bool ConfigureMethodAcceptsServiceProvider, - EndpointConfiguration Configuration -); +internal readonly record struct RequestHandlerClass : IComparable, IComparable +{ + public required string Name { get; init; } + public required bool IsStatic { get; init; } + public required bool HasConfigureMethod { get; init; } + public required bool ConfigureMethodAcceptsServiceProvider { get; init; } + public required EndpointConfiguration Configuration { get; init; } + + public int CompareTo(RequestHandlerClass other) + { + return string.Compare(Name, other.Name, StringComparison.Ordinal); + } + + public int CompareTo(object? obj) + { + if (obj is null) + return 1; + return obj is RequestHandlerClass other ? CompareTo(other) : throw new ArgumentException($"Object must be of type {nameof(RequestHandlerClass)}"); + } + + public static bool operator <(RequestHandlerClass left, RequestHandlerClass right) + { + return left.CompareTo(right) < 0; + } + + public static bool operator >(RequestHandlerClass left, RequestHandlerClass right) + { + return left.CompareTo(right) > 0; + } + + public static bool operator <=(RequestHandlerClass left, RequestHandlerClass right) + { + return left.CompareTo(right) <= 0; + } + + public static bool operator >=(RequestHandlerClass left, RequestHandlerClass right) + { + return left.CompareTo(right) >= 0; + } +} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassHelper.cs similarity index 81% rename from src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs rename to src/GeneratedEndpoints/Common/RequestHandlerClassHelper.cs index 3226132..74df35c 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassHelper.cs @@ -4,36 +4,31 @@ namespace GeneratedEndpoints.Common; -internal sealed class RequestHandlerClassCacheEntry +internal static class RequestHandlerClassHelper { - private readonly object _lock = new(); - private RequestHandlerClass _value; - private bool _initialized; - - public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, CancellationToken cancellationToken) + public static RequestHandlerClass? Create(IMethodSymbol methodSymbol, CancellationToken cancellationToken) { - if (_initialized) - return _value; - - lock (_lock) - { - if (_initialized) - return _value; + cancellationToken.ThrowIfCancellationRequested(); - cancellationToken.ThrowIfCancellationRequested(); + var classSymbol = methodSymbol.ContainingType; + if (classSymbol.TypeKind != TypeKind.Class) + return null; - var name = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var isStatic = classSymbol.IsStatic; - var configureMethodDetails = GetConfigureMethodDetails(classSymbol, cancellationToken); + var name = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var isStatic = classSymbol.IsStatic; + var configureMethodDetails = GetConfigureMethodDetails(classSymbol, cancellationToken); + var classConfiguration = EndpointConfigurationFactory.Create(classSymbol); - var classConfiguration = EndpointConfigurationFactory.Create(classSymbol); + var requestHandlerClass = new RequestHandlerClass + { + Name = name, + IsStatic = isStatic, + HasConfigureMethod = configureMethodDetails.HasConfigureMethod, + ConfigureMethodAcceptsServiceProvider = configureMethodDetails.ConfigureMethodAcceptsServiceProvider, + Configuration = classConfiguration, + }; - _value = new RequestHandlerClass(name, isStatic, configureMethodDetails.HasConfigureMethod, - configureMethodDetails.ConfigureMethodAcceptsServiceProvider, classConfiguration - ); - _initialized = true; - return _value; - } + return requestHandlerClass; } private static ConfigureMethodDetails GetConfigureMethodDetails(INamedTypeSymbol classSymbol, CancellationToken cancellationToken) diff --git a/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs b/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs index 54e2958..52669c4 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs @@ -1,8 +1,41 @@ namespace GeneratedEndpoints.Common; -internal readonly record struct RequestHandlerMethod( - string Name, - bool IsStatic, - EquatableImmutableArray Parameters, - EndpointConfiguration Configuration -); +internal readonly record struct RequestHandlerMethod : IComparable, IComparable +{ + public required string Name { get; init; } + public required bool IsStatic { get; init; } + public required EquatableImmutableArray Parameters { get; init; } + public required EndpointConfiguration Configuration { get; init; } + + public int CompareTo(RequestHandlerMethod other) + { + return string.Compare(Name, other.Name, StringComparison.Ordinal); + } + + public int CompareTo(object? obj) + { + if (obj is null) + return 1; + return obj is RequestHandlerMethod other ? CompareTo(other) : throw new ArgumentException($"Object must be of type {nameof(RequestHandlerMethod)}"); + } + + public static bool operator <(RequestHandlerMethod left, RequestHandlerMethod right) + { + return left.CompareTo(right) < 0; + } + + public static bool operator >(RequestHandlerMethod left, RequestHandlerMethod right) + { + return left.CompareTo(right) > 0; + } + + public static bool operator <=(RequestHandlerMethod left, RequestHandlerMethod right) + { + return left.CompareTo(right) <= 0; + } + + public static bool operator >=(RequestHandlerMethod left, RequestHandlerMethod right) + { + return left.CompareTo(right) >= 0; + } +} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerMethodHelper.cs b/src/GeneratedEndpoints/Common/RequestHandlerMethodHelper.cs new file mode 100644 index 0000000..72193dc --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerMethodHelper.cs @@ -0,0 +1,25 @@ +using Microsoft.CodeAnalysis; + +namespace GeneratedEndpoints.Common; + +internal static class RequestHandlerMethodHelper +{ + public static RequestHandlerMethod Create(IMethodSymbol methodSymbol, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + var name = methodSymbol.Name; + var isStatic = methodSymbol.IsStatic; + var parameters = methodSymbol.GetParameters(cancellationToken); + var configuration = EndpointConfigurationFactory.Create(methodSymbol); + var requestHandlerMethod = new RequestHandlerMethod + { + Name = name, + IsStatic = isStatic, + Parameters = parameters, + Configuration = configuration, + }; + + return requestHandlerMethod; + } +} diff --git a/src/GeneratedEndpoints/Common/StringExtensions.cs b/src/GeneratedEndpoints/Common/StringExtensions.cs index 5867811..1ca906a 100644 --- a/src/GeneratedEndpoints/Common/StringExtensions.cs +++ b/src/GeneratedEndpoints/Common/StringExtensions.cs @@ -1,4 +1,5 @@ using System.Globalization; +using static GeneratedEndpoints.Common.Constants; namespace GeneratedEndpoints.Common; @@ -75,4 +76,12 @@ public static string NormalizeOrDefaultString(this string? value, string default { return string.IsNullOrWhiteSpace(value) ? defaultValue : value!.Trim(); } + + public static string RemoveAsyncSuffix(this string methodName) + { + if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) + return methodName[..^AsyncSuffix.Length]; + + return methodName; + } } diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 00181d4..dde1428 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -1,5 +1,4 @@ using System.Collections.Immutable; -using System.Runtime.CompilerServices; using GeneratedEndpoints.Common; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -14,8 +13,6 @@ namespace GeneratedEndpoints; [Generator] public sealed class MinimalApiGenerator : IIncrementalGenerator { - private static readonly ConditionalWeakTable RequestHandlerClassCache = new(); - public void Initialize(IncrementalGeneratorInitializationContext context) { context.RegisterPostInitializationOutput(RegisterAttributes); @@ -33,27 +30,29 @@ public void Initialize(IncrementalGeneratorInitializationContext context) requestHandlerProviders.Add(handlers); } - var requestHandlers = CombineRequestHandlers(requestHandlerProviders.MoveToImmutable()) - .Select((x, _) => x.ToEquatableImmutableArray()); + var requestHandlers = CombineRequestHandlers(requestHandlerProviders); context.RegisterSourceOutput(requestHandlers, GenerateSource); } - private static IncrementalValueProvider> CombineRequestHandlers( - ImmutableArray>> handlerProviders + private static IncrementalValueProvider> CombineRequestHandlers( + ImmutableArray>>.Builder builder ) { - if (handlerProviders.IsDefaultOrEmpty) + var handlerProvidersArray = builder.MoveToImmutable(); + + if (handlerProvidersArray.IsDefaultOrEmpty) throw new InvalidOperationException("No HTTP attribute definitions were provided."); - var combined = handlerProviders[0]; - for (var i = 1; i < handlerProviders.Length; i++) + var combined = handlerProvidersArray[0]; + for (var i = 1; i < handlerProvidersArray.Length; i++) { - combined = combined.Combine(handlerProviders[i]) + combined = combined.Combine(handlerProvidersArray[i]) .Select(static (x, _) => x.Left.AddRange(x.Right)); } - return combined; + return combined + .Select((x, _) => x.ToEquatableImmutableArray()); } private static void RegisterAttributes(IncrementalGeneratorPostInitializationContext context) @@ -95,11 +94,11 @@ private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToke return null; var attribute = context.Attributes[0]; - var requestHandlerClass = GetRequestHandlerClass(methodSymbol, cancellationToken); + var requestHandlerClass = RequestHandlerClassHelper.Create(methodSymbol, cancellationToken); if (requestHandlerClass is null) return null; - var requestHandlerMethod = GetRequestHandlerMethod(methodSymbol, cancellationToken); + var requestHandlerMethod = RequestHandlerMethodHelper.Create(methodSymbol, cancellationToken); var (httpMethod, pattern, name) = GetRequestHandlerAttribute(methodSymbol, attribute, cancellationToken); @@ -127,54 +126,22 @@ CancellationToken cancellationToken var httpMethod = HttpAttributeDefinitionsByName.TryGetValue(attributeName, out var definition) ? definition.Verb : ""; var pattern = attribute.GetConstructorStringValue() ?? ""; var name = attribute.GetNamedStringValue(NameAttributeNamedParameter); - name ??= RemoveAsyncSuffix(methodSymbol.Name); + name ??= methodSymbol.Name.RemoveAsyncSuffix(); return (httpMethod, pattern, name); } - private static string RemoveAsyncSuffix(string methodName) - { - if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) - return methodName[..^AsyncSuffix.Length]; - - return methodName; - } - - private static RequestHandlerClass? GetRequestHandlerClass(IMethodSymbol methodSymbol, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - var classSymbol = methodSymbol.ContainingType; - if (classSymbol.TypeKind != TypeKind.Class) - return null; - - var cacheEntry = RequestHandlerClassCache.GetValue(classSymbol, static _ => new RequestHandlerClassCacheEntry()); - var requestHandlerClass = cacheEntry.GetOrCreate(classSymbol, cancellationToken); - - return requestHandlerClass; - } - - private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol methodSymbol, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - var name = methodSymbol.Name; - var isStatic = methodSymbol.IsStatic; - var parameters = methodSymbol.GetParameters(cancellationToken); - var configuration = EndpointConfigurationFactory.Create(methodSymbol); - var requestHandlerMethod = new RequestHandlerMethod(name, isStatic, parameters, configuration); - - return requestHandlerMethod; - } - private static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) { context.CancellationToken.ThrowIfCancellationRequested(); var normalized = NormalizeRequestHandlers(requestHandlers); - AddEndpointHandlersGenerator.GenerateSource(context, normalized); - UseEndpointHandlersGenerator.GenerateSource(context, normalized); + var grouped = normalized.GroupBy(x => x.Class).OrderBy(x => x.Key) + .ToImmutableSortedDictionary(x => x.Key, x => x.OrderBy(y => y.Method).ToImmutableArray()); + + AddEndpointHandlersGenerator.GenerateSource(context, grouped); + UseEndpointHandlersGenerator.GenerateSource(context, grouped); } private static EquatableImmutableArray NormalizeRequestHandlers(EquatableImmutableArray requestHandlers) @@ -182,8 +149,10 @@ private static EquatableImmutableArray NormalizeRequestHandlers( if (requestHandlers.Count <= 1) return requestHandlers; - requestHandlers.SortInPlace(RequestHandlerComparer.Instance); ResolveEndpointNameCollisions(requestHandlers); +#pragma warning disable S125 + //requestHandlers.SortInPlace(RequestHandlerComparer.Instance); +#pragma warning restore S125 return requestHandlers; } diff --git a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs index 93ba94e..4c9f37b 100644 --- a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs +++ b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs @@ -13,11 +13,13 @@ namespace GeneratedEndpoints; internal static class UseEndpointHandlersGenerator { - public static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) + public static void GenerateSource(SourceProductionContext context, ImmutableSortedDictionary> grouped) { context.CancellationToken.ThrowIfCancellationRequested(); - var source = GetUseEndpointHandlersStringBuilder(requestHandlers); + var requestHandlers = grouped.Values.SelectMany(x => x).ToImmutableList(); + + var source = new StringBuilder(); source.AppendLine(FileHeader); source.AppendLine(); @@ -26,7 +28,7 @@ public static void GenerateSource(SourceProductionContext context, EquatableImmu source.AppendLine("using Microsoft.AspNetCore.Http;"); source.AppendLine("using Microsoft.AspNetCore.Mvc;"); source.AppendLine("using Microsoft.AspNetCore.Routing;"); - if (HasRateLimitedHandlers(requestHandlers)) + if (AddUsingRateLimiting(requestHandlers)) source.AppendLine("using Microsoft.AspNetCore.RateLimiting;"); source.AppendLine("using Microsoft.Extensions.DependencyInjection;"); source.AppendLine(); @@ -49,17 +51,23 @@ public static void GenerateSource(SourceProductionContext context, EquatableImmu source.AppendLine(" {"); - var groupedClasses = GetClassesWithMapGroups(requestHandlers); + var groupedClasses = GetClassesWithGroups(requestHandlers); for (var index = 0; index < groupedClasses.Count; index++) { var groupedClass = groupedClasses[index]; + var configuration = groupedClass.Configuration; + if (!configuration.Group.HasValue) + continue; + + var group = configuration.Group.Value; + source.Append(" var "); - source.Append(groupedClass.Configuration.GroupIdentifier); + source.Append(group.Identifier); source.Append(" = builder.MapGroup("); - source.Append(groupedClass.Configuration.GroupPattern!.ToStringLiteral()); + source.Append(group.Pattern.ToStringLiteral()); source.Append(')'); - AppendEndpointConfiguration(source, " ", groupedClass.Configuration); + AppendEndpointConfiguration(source, " ", configuration); source.AppendLine(";"); } @@ -87,7 +95,7 @@ public static void GenerateSource(SourceProductionContext context, EquatableImmu context.AddSource(UseEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); } - private static bool HasRateLimitedHandlers(EquatableImmutableArray requestHandlers) + private static bool AddUsingRateLimiting(ImmutableList requestHandlers) { for (var index = 0; index < requestHandlers.Count; index++) { @@ -99,25 +107,29 @@ private static bool HasRateLimitedHandlers(EquatableImmutableArray GetClassesWithMapGroups(EquatableImmutableArray requestHandlers) + private static List GetClassesWithGroups(ImmutableList requestHandlers) { - var groupedClasses = new List(); if (requestHandlers.Count == 0) - return groupedClasses; + return []; - var seen = new HashSet(StringComparer.Ordinal); + HashSet? seen = null; + List? groupedClasses = null; for (var index = 0; index < requestHandlers.Count; index++) { var handler = requestHandlers[index]; var handlerClass = handler.Class; - if (handlerClass.Configuration.GroupPattern is null) + if (!handlerClass.Configuration.Group.HasValue) + continue; + + seen ??= new HashSet(StringComparer.Ordinal); + if (!seen.Add(handlerClass.Name)) continue; - if (seen.Add(handlerClass.Name)) - groupedClasses.Add(handlerClass); + groupedClasses ??= []; + groupedClasses.Add(handlerClass); } - return groupedClasses; + return groupedClasses ?? []; } private static void GenerateMapRequestHandler(StringBuilder source, RequestHandler requestHandler) @@ -126,7 +138,7 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl var configureAcceptsServiceProvider = requestHandler.Class.ConfigureMethodAcceptsServiceProvider; var indent = wrapWithConfigure ? " " : " "; var continuationIndent = indent + " "; - var routeBuilderIdentifier = requestHandler.Class.Configuration.GroupIdentifier ?? "builder"; + var routeBuilderIdentifier = requestHandler.Class.Configuration.Group?.Identifier ?? "builder"; if (wrapWithConfigure) { @@ -200,7 +212,7 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl source.Append(')'); var configuration = requestHandler.Method.Configuration; - if (requestHandler.Class.Configuration.GroupPattern is null) + if (!requestHandler.Class.Configuration.Group.HasValue) configuration = MergeEndpointConfigurations(requestHandler.Class.Configuration, configuration); if (!string.IsNullOrEmpty(requestHandler.Name)) @@ -261,12 +273,12 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.Append(')'); } - if (!string.IsNullOrEmpty(configuration.GroupName)) + if (configuration.Group is { Name.Length: > 0 }) { source.AppendLine(); source.Append(indent); source.Append(".WithGroupName("); - source.Append(configuration.GroupName.ToStringLiteral()); + source.Append(configuration.Group.Value.Name.ToStringLiteral()); source.Append(')'); } @@ -493,9 +505,9 @@ private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfigu var disableRequestTimeout = methodConfiguration.DisableRequestTimeout || classConfiguration.DisableRequestTimeout; var withRequestTimeout = methodConfiguration.WithRequestTimeout || classConfiguration.WithRequestTimeout; - var groupIdentifier = methodConfiguration.GroupIdentifier ?? classConfiguration.GroupIdentifier; - var groupPattern = methodConfiguration.GroupPattern ?? classConfiguration.GroupPattern; - var groupName = methodConfiguration.GroupName ?? classConfiguration.GroupName; + var groupIdentifier = methodConfiguration.Group?.Identifier ?? classConfiguration.Group?.Identifier; + var groupPattern = methodConfiguration.Group?.Pattern ?? classConfiguration.Group?.Pattern; + var groupName = methodConfiguration.Group?.Name ?? classConfiguration.Group?.Name; string? requestTimeoutPolicyName = null; if (methodConfiguration.WithRequestTimeout) @@ -538,9 +550,14 @@ private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfigu WithRequestTimeout = withRequestTimeout, RequestTimeoutPolicyName = requestTimeoutPolicyName, Order = order, - GroupIdentifier = groupIdentifier, - GroupPattern = groupPattern, - GroupName = groupName, + Group = groupIdentifier is not null && groupPattern is not null + ? new EndpointGroup + { + Identifier = groupIdentifier, + Pattern = groupPattern, + Name = groupName, + } + : null, }; } @@ -611,21 +628,6 @@ private static EquatableImmutableArray MergeUnion(EquatableImmutableArra }; } - private static StringBuilder GetUseEndpointHandlersStringBuilder(EquatableImmutableArray requestHandlers) - { - const int baseSize = 4096; - const int perHandler = 512; - - var handlerCount = Math.Max(requestHandlers.Count, 0); - var estimate = baseSize + (long)perHandler * handlerCount; - estimate = (long)(estimate * 1.10); - - if (estimate > int.MaxValue) - estimate = int.MaxValue; - - return StringBuilderPool.Get((int)Math.Max(baseSize, estimate)); - } - private static void AppendAdditionalContentTypes(StringBuilder source, EquatableImmutableArray? additionalContentTypes) { if (additionalContentTypes is not { Count: > 0 }) From 15a2c0e35cf87b734f08cb2257719ef79b1ffca4 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 13:19:56 -0500 Subject: [PATCH 2/6] Refactor. --- src/GeneratedEndpoints/Common/Constants.cs | 16 +- .../Common/MethodSymbolExtensions.cs | 4 + .../Common/NamedTypeSymbolExtensions.cs | 268 +++++++++++++++--- .../Common/RequestHandlerClassHelper.cs | 60 +--- .../Common/TypeSymbolExtensions.cs | 34 ++- src/GeneratedEndpoints/MinimalApiGenerator.cs | 96 +++---- 6 files changed, 314 insertions(+), 164 deletions(-) diff --git a/src/GeneratedEndpoints/Common/Constants.cs b/src/GeneratedEndpoints/Common/Constants.cs index 8ef77d1..60f2ee2 100644 --- a/src/GeneratedEndpoints/Common/Constants.cs +++ b/src/GeneratedEndpoints/Common/Constants.cs @@ -47,12 +47,6 @@ internal static partial class Constants internal const string SummaryAttributeName = "SummaryAttribute"; internal const string SummaryAttributeHint = $"{SummaryAttributeFullyQualifiedName}.gs.cs"; - internal const string DisplayNameAttributeName = nameof(DisplayNameAttribute); - internal const string DescriptionAttributeName = nameof(DescriptionAttribute); - internal const string AllowAnonymousAttributeName = "AllowAnonymousAttribute"; - internal const string TagsAttributeName = "TagsAttribute"; - internal const string ExcludeFromDescriptionAttributeName = "ExcludeFromDescriptionAttribute"; - internal const string EndpointFilterAttributeName = "EndpointFilterAttribute"; internal const string EndpointFilterAttributeHint = $"{EndpointFilterAttributeFullyQualifiedName}.gs.cs"; @@ -82,15 +76,7 @@ internal static partial class Constants internal const string AsyncSuffix = "Async"; internal const string ApplicationJsonContentType = "application/json"; internal const string GlobalPrefix = "global::"; - internal const string Dot = "."; - - internal static readonly string[] AttributesNamespaceParts = AttributesNamespace.Split('.'); - internal static readonly string[] AspNetCoreHttpNamespaceParts = ["Microsoft", "AspNetCore", "Http"]; - internal static readonly string[] AspNetCoreMvcNamespaceParts = ["Microsoft", "AspNetCore", "Mvc"]; - internal static readonly string[] AspNetCoreAuthorizationNamespaceParts = ["Microsoft", "AspNetCore", "Authorization"]; - internal static readonly string[] AspNetCoreRoutingNamespaceParts = ["Microsoft", "AspNetCore", "Routing"]; - internal static readonly string[] ExtensionsDependencyInjectionNamespaceParts = ["Microsoft", "Extensions", "DependencyInjection"]; - internal static readonly string[] ComponentModelNamespaceParts = ["System", "ComponentModel"]; + private const string BaseNamespace = "Microsoft.AspNetCore.Generated"; private const string AttributesNamespace = $"{BaseNamespace}.Attributes"; private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}"; diff --git a/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs b/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs index bb4cdfc..c5261e1 100644 --- a/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs +++ b/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs @@ -69,6 +69,10 @@ private static string GetBindingPrefix(IParameterSymbol parameter) return bindingPrefix; } + internal static readonly string[] AspNetCoreHttpNamespaceParts = ["Microsoft", "AspNetCore", "Http"]; + internal static readonly string[] AspNetCoreMvcNamespaceParts = ["Microsoft", "AspNetCore", "Mvc"]; + internal static readonly string[] ExtensionsDependencyInjectionNamespaceParts = ["Microsoft", "Extensions", "DependencyInjection"]; + private static BindingSource GetBindingSourceFromAttributeClass(INamedTypeSymbol attributeClass) { var definition = attributeClass.OriginalDefinition; diff --git a/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs b/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs index 36c1bca..187b1e9 100644 --- a/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs +++ b/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs @@ -6,45 +6,239 @@ namespace GeneratedEndpoints.Common; /// Provides extension methods for working with named type symbols. internal static class NamedTypeSymbolExtensions { - public static RequestHandlerAttributeKind GetRequestHandlerAttributeKind(this INamedTypeSymbol definition) + public static RequestHandlerAttributeKind GetRequestHandlerAttributeKind(this ITypeSymbol definition) { - if (AttributeSymbolMatcher.IsAttribute(definition, DisplayNameAttributeName, ComponentModelNamespaceParts)) - return RequestHandlerAttributeKind.DisplayName; - - if (AttributeSymbolMatcher.IsAttribute(definition, DescriptionAttributeName, ComponentModelNamespaceParts)) - return RequestHandlerAttributeKind.Description; - - if (AttributeSymbolMatcher.IsAttribute(definition, AllowAnonymousAttributeName, AspNetCoreAuthorizationNamespaceParts)) - return RequestHandlerAttributeKind.AllowAnonymous; - - if (AttributeSymbolMatcher.IsAttribute(definition, TagsAttributeName, AspNetCoreHttpNamespaceParts)) - return RequestHandlerAttributeKind.Tags; - - if (AttributeSymbolMatcher.IsAttribute(definition, ExcludeFromDescriptionAttributeName, AspNetCoreRoutingNamespaceParts)) - return RequestHandlerAttributeKind.ExcludeFromDescription; - - if (!AttributeSymbolMatcher.IsInNamespace(definition.ContainingNamespace, AttributesNamespaceParts)) - return RequestHandlerAttributeKind.None; - - return definition.Name switch + return definition switch { - ShortCircuitAttributeName => RequestHandlerAttributeKind.ShortCircuit, - DisableValidationAttributeName => RequestHandlerAttributeKind.DisableValidation, - DisableRequestTimeoutAttributeName => RequestHandlerAttributeKind.DisableRequestTimeout, - RequestTimeoutAttributeName => RequestHandlerAttributeKind.RequestTimeout, - OrderAttributeName => RequestHandlerAttributeKind.Order, - MapGroupAttributeName => RequestHandlerAttributeKind.MapGroup, - SummaryAttributeName => RequestHandlerAttributeKind.Summary, - AcceptsAttributeName => RequestHandlerAttributeKind.Accepts, - ProducesResponseAttributeName => RequestHandlerAttributeKind.ProducesResponse, - RequireAuthorizationAttributeName => RequestHandlerAttributeKind.RequireAuthorization, - RequireCorsAttributeName => RequestHandlerAttributeKind.RequireCors, - RequireHostAttributeName => RequestHandlerAttributeKind.RequireHost, - RequireRateLimitingAttributeName => RequestHandlerAttributeKind.RequireRateLimiting, - EndpointFilterAttributeName => RequestHandlerAttributeKind.EndpointFilter, - DisableAntiforgeryAttributeName => RequestHandlerAttributeKind.DisableAntiforgery, - ProducesProblemAttributeName => RequestHandlerAttributeKind.ProducesProblem, - ProducesValidationProblemAttributeName => RequestHandlerAttributeKind.ProducesValidationProblem, + { + MetadataName: "DisplayNameAttribute", + ContainingNamespace: + { + Name: "ComponentModel", + ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true }, + }, + } => RequestHandlerAttributeKind.DisplayName, + { + MetadataName: "DescriptionAttribute", + ContainingNamespace: + { + Name: "ComponentModel", + ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true }, + }, + } => RequestHandlerAttributeKind.Description, + { + MetadataName: "AllowAnonymousAttribute", + ContainingNamespace: + { + Name: "Authorization", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}, + }, + } => RequestHandlerAttributeKind.AllowAnonymous, + { + MetadataName: "TagsAttribute", + ContainingNamespace: + { + Name: "Http", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}, + }, + } => RequestHandlerAttributeKind.Tags, + { + MetadataName: "ExcludeFromDescriptionAttribute", + ContainingNamespace: + { + Name: "Routing", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}, + }, + } => RequestHandlerAttributeKind.ExcludeFromDescription, + { + MetadataName: "ExcludeFromDescriptionAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.ExcludeFromDescription, + { + MetadataName: "ShortCircuitAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.ShortCircuit, + { + MetadataName: "DisableValidationAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.DisableValidation, + { + MetadataName: "DisableRequestTimeoutAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.DisableRequestTimeout, + { + MetadataName: "RequestTimeoutAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.RequestTimeout, + { + MetadataName: "OrderAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.Order, + { + MetadataName: "MapGroupAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.MapGroup, + { + MetadataName: "SummaryAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.Summary, + { + MetadataName: "AcceptsAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.Accepts, + { + MetadataName: "AcceptsAttribute`1", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.Accepts, + { + MetadataName: "ProducesResponseAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.ProducesResponse, + { + MetadataName: "ProducesResponseAttribute`1", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.ProducesResponse, + { + MetadataName: "RequireAuthorizationAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.RequireAuthorization, + { + MetadataName: "RequireCorsAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.RequireCors, + { + MetadataName: "RequireHostAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.RequireHost, + { + MetadataName: "RequireRateLimitingAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.RequireRateLimiting, + { + MetadataName: "EndpointFilterAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.EndpointFilter, + { + MetadataName: "EndpointFilterAttribute`1", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.EndpointFilter, + { + MetadataName: "DisableAntiforgeryAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.DisableAntiforgery, + { + MetadataName: "ProducesProblemAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.ProducesProblem, + { + MetadataName: "ProducesValidationProblemAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: { Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, + }, + } => RequestHandlerAttributeKind.ProducesValidationProblem, _ => RequestHandlerAttributeKind.None, }; } diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassHelper.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassHelper.cs index 74df35c..806a254 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClassHelper.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassHelper.cs @@ -1,5 +1,4 @@ using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; using static GeneratedEndpoints.Common.Constants; namespace GeneratedEndpoints.Common; @@ -80,7 +79,7 @@ private static bool IsConfigureMethod(IMethodSymbol methodSymbol, out bool accep if (methodSymbol.Parameters.Length == 2) { var serviceProviderParameter = methodSymbol.Parameters[1]; - if (!IsServiceProviderParameter(serviceProviderParameter.Type)) + if (!serviceProviderParameter.Type.IsIServiceProvider()) return false; acceptsServiceProvider = true; @@ -89,64 +88,9 @@ private static bool IsConfigureMethod(IMethodSymbol methodSymbol, out bool accep if (!methodSymbol.ReturnsVoid) return false; - if (!HasEndpointConventionBuilderConstraint(builderTypeParameter, methodSymbol)) + if (!builderTypeParameter.ConstraintTypes.Any(x => x.IsIEndpointConventionBuilder())) return false; return true; } - - private static bool IsServiceProviderParameter(ITypeSymbol typeSymbol) - { - return MatchesServiceProvider(typeSymbol); - } - - private static bool HasEndpointConventionBuilderConstraint(ITypeParameterSymbol builderTypeParameter, IMethodSymbol methodSymbol) - { - var symbolMatches = builderTypeParameter.ConstraintTypes.Any(MatchesEndpointConventionBuilder); - if (symbolMatches) - return true; - - return methodSymbol.DeclaringSyntaxReferences - .Select(reference => reference.GetSyntax()) - .OfType() - .SelectMany(methodSyntax => methodSyntax.ConstraintClauses) - .Where(clause => string.Equals(clause.Name.Identifier.ValueText, builderTypeParameter.Name, StringComparison.Ordinal)) - .SelectMany(clause => clause.Constraints.OfType()) - .Any(constraint => IsEndpointConventionBuilderIdentifier(constraint.Type)); - } - - private static bool IsEndpointConventionBuilderIdentifier(TypeSyntax typeSyntax) - { - return typeSyntax switch - { - QualifiedNameSyntax qualified => IsEndpointConventionBuilderIdentifier(qualified.Right), - AliasQualifiedNameSyntax alias => IsEndpointConventionBuilderIdentifier(alias.Name), - SimpleNameSyntax simple => string.Equals(simple.Identifier.ValueText, "IEndpointConventionBuilder", StringComparison.Ordinal), - _ => false, - }; - } - - private static bool MatchesEndpointConventionBuilder(ITypeSymbol typeSymbol) - { - if (typeSymbol is not INamedTypeSymbol namedType) - return false; - - if (!string.Equals(namedType.Name, "IEndpointConventionBuilder", StringComparison.Ordinal)) - return false; - - var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? ""; - return string.Equals(containingNamespace, "Microsoft.AspNetCore.Builder", StringComparison.Ordinal); - } - - private static bool MatchesServiceProvider(ITypeSymbol typeSymbol) - { - if (typeSymbol is not INamedTypeSymbol namedType) - return false; - - if (!string.Equals(namedType.Name, "IServiceProvider", StringComparison.Ordinal)) - return false; - - var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? ""; - return string.Equals(containingNamespace, "System", StringComparison.Ordinal); - } } diff --git a/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs b/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs index 0b69f45..7ad9d5e 100644 --- a/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs +++ b/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs @@ -4,12 +4,36 @@ namespace GeneratedEndpoints.Common; internal static class TypeSymbolExtensions { + public static bool IsIEndpointConventionBuilder(this ITypeSymbol symbol) + { + return symbol is + { + MetadataName: "IEndpointConventionBuilder", + ContainingNamespace: + { + Name: "Builder", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }; + } + + public static bool IsIServiceProvider(this ITypeSymbol symbol) + { + return symbol is + { + MetadataName: "IServiceProvider", + ContainingNamespace: + { + Name: "System", ContainingNamespace.IsGlobalNamespace: true, + }, + }; + } + public static bool IsAwaitable(this ITypeSymbol symbol) { return symbol switch { - INamedTypeSymbol - { + { MetadataName: "ValueTask`1", ContainingNamespace: { @@ -17,7 +41,7 @@ public static bool IsAwaitable(this ITypeSymbol symbol) ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, }, } - or INamedTypeSymbol + or { MetadataName: "Task`1", ContainingNamespace: @@ -26,7 +50,7 @@ or INamedTypeSymbol ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, }, } - or INamedTypeSymbol + or { MetadataName: "ValueTask", ContainingNamespace: @@ -35,7 +59,7 @@ or INamedTypeSymbol ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, }, } - or INamedTypeSymbol + or { MetadataName: "Task", ContainingNamespace: diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index dde1428..873acfd 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -17,44 +17,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) { context.RegisterPostInitializationOutput(RegisterAttributes); - var requestHandlerProviders = ImmutableArray.CreateBuilder>>(HttpAttributeDefinitions.Length); - - for (var index = 0; index < HttpAttributeDefinitions.Length; index++) - { - var definition = HttpAttributeDefinitions[index]; - var handlers = context.SyntaxProvider - .ForAttributeWithMetadataName(definition.FullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) - .WhereNotNull() - .Collect(); - - requestHandlerProviders.Add(handlers); - } - - var requestHandlers = CombineRequestHandlers(requestHandlerProviders); + var requestHandlers = GetRequestHandlers(context); context.RegisterSourceOutput(requestHandlers, GenerateSource); } - private static IncrementalValueProvider> CombineRequestHandlers( - ImmutableArray>>.Builder builder - ) - { - var handlerProvidersArray = builder.MoveToImmutable(); - - if (handlerProvidersArray.IsDefaultOrEmpty) - throw new InvalidOperationException("No HTTP attribute definitions were provided."); - - var combined = handlerProvidersArray[0]; - for (var i = 1; i < handlerProvidersArray.Length; i++) - { - combined = combined.Combine(handlerProvidersArray[i]) - .Select(static (x, _) => x.Left.AddRange(x.Right)); - } - - return combined - .Select((x, _) => x.ToEquatableImmutableArray()); - } - private static void RegisterAttributes(IncrementalGeneratorPostInitializationContext context) { foreach (var definition in HttpAttributeDefinitions) @@ -79,6 +46,44 @@ private static void RegisterAttributes(IncrementalGeneratorPostInitializationCon context.AddSource(SummaryAttributeHint, SummaryAttributeSourceText); } + private static IncrementalValueProvider> GetRequestHandlers(IncrementalGeneratorInitializationContext context) + { + var list = new List>>(HttpAttributeDefinitions.Length); + + for (var index = 0; index < HttpAttributeDefinitions.Length; index++) + { + var definition = HttpAttributeDefinitions[index]; + var handlers = context.SyntaxProvider + .ForAttributeWithMetadataName(definition.FullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) + .WhereNotNull() + .Collect(); + + list.Add(handlers); + } + + if (list.Count == 0) + throw new InvalidOperationException("No HTTP attribute definitions were provided."); + + var combined = list[0]; + for (var i = 1; i < list.Count; i++) + { + combined = combined.Combine(list[i]) + .Select(static (x, ct) => + { + ct.ThrowIfCancellationRequested(); + return x.Left.AddRange(x.Right); + } + ); + } + + return combined.Select((x, ct) => + { + ct.ThrowIfCancellationRequested(); + return x.ToEquatableImmutableArray(); + } + ); + } + private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -135,30 +140,23 @@ private static void GenerateSource(SourceProductionContext context, EquatableImm { context.CancellationToken.ThrowIfCancellationRequested(); - var normalized = NormalizeRequestHandlers(requestHandlers); + ResolveEndpointNameCollisions(requestHandlers); - var grouped = normalized.GroupBy(x => x.Class).OrderBy(x => x.Key) - .ToImmutableSortedDictionary(x => x.Key, x => x.OrderBy(y => y.Method).ToImmutableArray()); + var grouped = requestHandlers.GroupBy(x => x.Class) + .OrderBy(x => x.Key) + .ToImmutableSortedDictionary(x => x.Key, x => x.OrderBy(y => y.Method) + .ToImmutableArray() + ); AddEndpointHandlersGenerator.GenerateSource(context, grouped); UseEndpointHandlersGenerator.GenerateSource(context, grouped); } - private static EquatableImmutableArray NormalizeRequestHandlers(EquatableImmutableArray requestHandlers) + private static void ResolveEndpointNameCollisions(EquatableImmutableArray requestHandlers) { if (requestHandlers.Count <= 1) - return requestHandlers; - - ResolveEndpointNameCollisions(requestHandlers); -#pragma warning disable S125 - //requestHandlers.SortInPlace(RequestHandlerComparer.Instance); -#pragma warning restore S125 - - return requestHandlers; - } + return; - private static void ResolveEndpointNameCollisions(EquatableImmutableArray requestHandlers) - { var raw = requestHandlers.AsArray(); if (raw is null) return; From 8c86ef4af2b2c3df301800aa0925976cccf0883d Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 13:29:36 -0500 Subject: [PATCH 3/6] Add tests for generic ProducesResponse attribute (#72) --- .../Common/SourceFactory.cs | 4 +++ ...3EC6390D4_AddEndpointHandlers.verified.txt | 24 +++++++++++++ ...3EC6390D4_MapEndpointHandlers.verified.txt | 35 ++++++++++++++++++ ...5DF1A025B_AddEndpointHandlers.verified.txt | 24 +++++++++++++ ...5DF1A025B_MapEndpointHandlers.verified.txt | 35 ++++++++++++++++++ ...FE6E1F139_MapEndpointHandlers.verified.txt | 1 + ...5BF1E7B54_AddEndpointHandlers.verified.txt | 24 +++++++++++++ ...5BF1E7B54_MapEndpointHandlers.verified.txt | 36 +++++++++++++++++++ .../GeneratedSourceTests.cs | 13 +++---- ...Attribute_AddEndpointHandlers.verified.txt | 24 +++++++++++++ ...Attribute_MapEndpointHandlers.verified.txt | 31 ++++++++++++++++ .../IndividualTests.cs | 10 +++++- 12 files changed, 254 insertions(+), 7 deletions(-) create mode 100644 tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_AddEndpointHandlers.verified.txt create mode 100644 tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_MapEndpointHandlers.verified.txt create mode 100644 tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_AddEndpointHandlers.verified.txt create mode 100644 tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_MapEndpointHandlers.verified.txt create mode 100644 tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_AddEndpointHandlers.verified.txt create mode 100644 tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_MapEndpointHandlers.verified.txt create mode 100644 tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_AddEndpointHandlers.verified.txt create mode 100644 tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_MapEndpointHandlers.verified.txt diff --git a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs index 4fef5bd..1989896 100644 --- a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs +++ b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs @@ -322,6 +322,7 @@ public static string BuildContractsAndBindingSource( bool includeAccepts, bool includeGenericAccepts, bool includeProducesResponse, + bool includeGenericProducesResponse, bool includeProducesProblem, bool includeProducesValidationProblem, bool includeSummaryAndDescription, @@ -383,6 +384,9 @@ public static string BuildContractsAndBindingSource( ); } + if (includeGenericProducesResponse) + builder.AppendLine($" [ProducesResponse(200, \"{producesContentType1 ?? "application/json"}\")]"); + if (includeProducesProblem) builder.AppendLine($" [ProducesProblem(500, \"{producesContentType1 ?? "application/problem+json"}\")]"); diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_AddEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_AddEndpointHandlers.verified.txt new file mode 100644 index 0000000..7a6643b --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_AddEndpointHandlers.verified.txt @@ -0,0 +1,24 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointServicesExtensions +{ + internal static void AddEndpointHandlers(this IServiceCollection services) + { + services.TryAddScoped(); + } +} diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_MapEndpointHandlers.verified.txt new file mode 100644 index 0000000..c9e81bb --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_MapEndpointHandlers.verified.txt @@ -0,0 +1,35 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointRouteBuilderExtensions +{ + internal static IEndpointRouteBuilder MapEndpointHandlers(this IEndpointRouteBuilder builder) + { + builder.MapGet("/contracts/{id:int}", global::ContractEndpoints.Handle) + .WithName("Handle") + .WithDisplayName("Contract endpoint") + .WithTags("Contracts", "Bindings") + .Produces(200, "application/problem+json") + .ProducesProblem(500, "application/problem+json") + .AllowAnonymous(); + + return builder; + } +} diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_AddEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_AddEndpointHandlers.verified.txt new file mode 100644 index 0000000..7a6643b --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_AddEndpointHandlers.verified.txt @@ -0,0 +1,24 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointServicesExtensions +{ + internal static void AddEndpointHandlers(this IServiceCollection services) + { + services.TryAddScoped(); + } +} diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_MapEndpointHandlers.verified.txt new file mode 100644 index 0000000..d535349 --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_MapEndpointHandlers.verified.txt @@ -0,0 +1,35 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointRouteBuilderExtensions +{ + internal static IEndpointRouteBuilder MapEndpointHandlers(this IEndpointRouteBuilder builder) + { + builder.MapGet("/contracts/{id:int}", global::ContractEndpoints.Handle) + .WithName("Handle") + .WithDisplayName("Contract endpoint") + .ExcludeFromDescription() + .Produces(200, "application/json") + .ProducesProblem(500, "application/json") + .RequireAuthorization("ContractsPolicy"); + + return builder; + } +} diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_9F5FE6E1F139_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_9F5FE6E1F139_MapEndpointHandlers.verified.txt index 70f58ad..74ba401 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_9F5FE6E1F139_MapEndpointHandlers.verified.txt +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_9F5FE6E1F139_MapEndpointHandlers.verified.txt @@ -30,6 +30,7 @@ internal static class EndpointRouteBuilderExtensions .WithTags("Contracts", "Bindings") .Accepts("application/xml") .Produces(200, "application/json", "text/json") + .Produces(200, "application/json") .ProducesProblem(500, "application/json") .ProducesValidationProblem(422, "application/json") .AllowAnonymous(); diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_AddEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_AddEndpointHandlers.verified.txt new file mode 100644 index 0000000..6e79c25 --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_AddEndpointHandlers.verified.txt @@ -0,0 +1,24 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointServicesExtensions +{ + internal static void AddEndpointHandlers(this IServiceCollection services) + { + services.TryAddScoped(); + } +} diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_MapEndpointHandlers.verified.txt new file mode 100644 index 0000000..a5b8b0d --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_MapEndpointHandlers.verified.txt @@ -0,0 +1,36 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointRouteBuilderExtensions +{ + internal static IEndpointRouteBuilder MapEndpointHandlers(this IEndpointRouteBuilder builder) + { + builder.MapGet("/contracts/{id:int}", global::GeneratedEndpointsTests.ContractEndpoints.Handle) + .WithName("Handle") + .WithDisplayName("Contract endpoint") + .WithTags("Contracts", "Bindings") + .Accepts("application/json") + .Produces(200, "application/json") + .ProducesValidationProblem(422, "application/problem+json") + .AllowAnonymous(); + + return builder; + } +} diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs index dd38cbb..ebf6aa6 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs @@ -171,17 +171,17 @@ await result.VerifyAsync("MapEndpointHandlers.g.cs") } [Theory] - [InlineData(true, true, true, true, true, true, true, true, true, true, true, true, true, false, true, false, "application/xml", "text/xml", + [InlineData(true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, true, false, "application/xml", "text/xml", "application/json", "text/json" )] - [InlineData(false, false, true, false, false, true, false, true, true, false, false, false, true, true, false, true, "application/custom", null, + [InlineData(false, false, true, false, false, true, false, true, false, true, false, false, true, true, false, true, false, "application/custom", null, "application/problem+json", null )] - [InlineData(true, true, false, true, true, false, true, false, false, true, true, true, false, false, true, true, null, null, null, null)] - [InlineData(false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, "application/xml", null, + [InlineData(true, true, false, true, true, false, true, false, false, false, true, true, true, false, false, true, true, null, null, null, null)] + [InlineData(false, true, false, true, false, true, false, true, false, true, false, false, true, false, true, false, true, "application/xml", null, "application/json", null )] - [InlineData(true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, null, "text/plain", null, "text/plain")] + [InlineData(true, false, true, false, true, false, true, false, true, false, true, false, true, true, false, true, false, null, "text/plain", null, "text/plain")] public async Task ContractsAndBindingMatrix( bool withNamespace, bool includeBindingNames, @@ -191,6 +191,7 @@ public async Task ContractsAndBindingMatrix( bool includeAccepts, bool includeGenericAccepts, bool includeProducesResponse, + bool includeGenericProducesResponse, bool includeProducesProblem, bool includeProducesValidationProblem, bool includeSummaryAndDescription, @@ -206,7 +207,7 @@ public async Task ContractsAndBindingMatrix( ) { var source = SourceFactory.BuildContractsAndBindingSource(includeBindingNames, includeAsParameters, includeFromServices, includeFromKeyedServices, - includeAccepts, includeGenericAccepts, includeProducesResponse, includeProducesProblem, includeProducesValidationProblem, + includeAccepts, includeGenericAccepts, includeProducesResponse, includeGenericProducesResponse, includeProducesProblem, includeProducesValidationProblem, includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, methodRequiresAuthorization, acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 ); diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_AddEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_AddEndpointHandlers.verified.txt new file mode 100644 index 0000000..6e79c25 --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_AddEndpointHandlers.verified.txt @@ -0,0 +1,24 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointServicesExtensions +{ + internal static void AddEndpointHandlers(this IServiceCollection services) + { + services.TryAddScoped(); + } +} diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_MapEndpointHandlers.verified.txt new file mode 100644 index 0000000..8a60c59 --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_MapEndpointHandlers.verified.txt @@ -0,0 +1,31 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointRouteBuilderExtensions +{ + internal static IEndpointRouteBuilder MapEndpointHandlers(this IEndpointRouteBuilder builder) + { + builder.MapGet("/contracts/{id:int}", global::GeneratedEndpointsTests.ContractEndpoints.Handle) + .WithName("Handle") + .Produces(200, "application/json"); + + return builder; + } +} diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.cs b/tests/GeneratedEndpoints.Tests/IndividualTests.cs index be6ee72..55adbde 100644 --- a/tests/GeneratedEndpoints.Tests/IndividualTests.cs +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.cs @@ -378,6 +378,13 @@ public async Task ProducesResponseAttribute() await VerifyIndividualAsync(source, nameof(ProducesResponseAttribute)); } + [Fact] + public async Task GenericProducesResponseAttribute() + { + var source = ContractScenario(includeGenericProducesResponse: true, producesContentType1: "application/json"); + await VerifyIndividualAsync(source, nameof(GenericProducesResponseAttribute)); + } + [Fact] public async Task ProducesResponseMultipleContentTypes() { @@ -546,6 +553,7 @@ private static string ContractScenario( bool includeAccepts = false, bool includeGenericAccepts = false, bool includeProducesResponse = false, + bool includeGenericProducesResponse = false, bool includeProducesProblem = false, bool includeProducesValidationProblem = false, bool includeSummaryAndDescription = false, @@ -561,7 +569,7 @@ private static string ContractScenario( ) { return SourceFactory.BuildContractsAndBindingSource(includeBindingNames, includeAsParameters, includeFromServices, includeFromKeyedServices, - includeAccepts, includeGenericAccepts, includeProducesResponse, includeProducesProblem, includeProducesValidationProblem, + includeAccepts, includeGenericAccepts, includeProducesResponse, includeGenericProducesResponse, includeProducesProblem, includeProducesValidationProblem, includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, methodRequiresAuthorization, acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 ); From 6c24e6ccba674dbd4dc9c475840d174b5f4f26ed Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:13:44 -0500 Subject: [PATCH 4/6] Refactored. --- .../Common/AttributeDataExtensions.cs | 8 +++ .../Common/Constants.GeneratedSources.cs | 64 ++++++++++--------- src/GeneratedEndpoints/Common/Constants.cs | 4 -- .../Common/EndpointConfigurationFactory.cs | 53 ++++++--------- .../Common/NamedTypeSymbolExtensions.cs | 1 - .../GetUserEndpoint.cs | 4 +- ...erationTests.AcceptsAttribute.verified.txt | 40 ++++++------ ...Tests.EndpointFilterAttribute.verified.txt | 10 +-- ...sts.ProducesResponseAttribute.verified.txt | 14 ++-- .../Common/SourceFactory.cs | 2 +- 10 files changed, 99 insertions(+), 101 deletions(-) diff --git a/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs index 1a10e01..9dab472 100644 --- a/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs +++ b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs @@ -56,6 +56,14 @@ internal static class AttributeDataExtensions return null; } + public static ITypeSymbol? GetConstructorTypeSymbol(this AttributeData attribute, int position = 0) + { + if (attribute.ConstructorArguments.Length > position && attribute.ConstructorArguments[position].Value is ITypeSymbol typeSymbol) + return typeSymbol; + + return null; + } + public static ITypeSymbol? GetNamedTypeSymbol(this AttributeData attribute, string namedParameter) { foreach (var namedArg in attribute.NamedArguments) diff --git a/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs b/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs index 6b6d8e5..1b90733 100644 --- a/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs +++ b/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs @@ -370,14 +370,9 @@ namespace {{AttributesNamespace}}; internal sealed class {{AcceptsAttributeName}} : global::System.Attribute { /// - /// Gets the request type accepted by the endpoint. + /// Gets the CLR type of the endpoint filter. /// - public global::System.Type? RequestType { get; init; } - - /// - /// Gets a value indicating whether the request body is optional. - /// - public bool IsOptional { get; init; } + public global::System.Type Type { get; } /// /// Gets the primary content type accepted by the endpoint. @@ -389,15 +384,22 @@ internal sealed class {{AcceptsAttributeName}} : global::System.Attribute /// public string[] AdditionalContentTypes { get; } + /// + /// Gets a value indicating whether the request body is optional. + /// + public bool IsOptional { get; init; } + /// /// Initializes a new instance of the class. /// + /// The CLR type of the request body. /// The primary content type accepted by the endpoint. /// Additional content types accepted by the endpoint. - public {{AcceptsAttributeName}}(string contentType = "application/json", params string[] additionalContentTypes) + public {{AcceptsAttributeName}}(global::System.Type type, string contentType = "application/json", params string[] additionalContentTypes) { - ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; - AdditionalContentTypes = additionalContentTypes ?? []; + Type = type; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes; } } @@ -409,14 +411,9 @@ internal sealed class {{AcceptsAttributeName}} : global::System.Attribute internal sealed class {{AcceptsAttributeName}} : global::System.Attribute { /// - /// Gets the request type accepted by the endpoint. + /// Gets the CLR type of the endpoint filter. /// - public global::System.Type RequestType => typeof(TRequest); - - /// - /// Gets a value indicating whether the request body is optional. - /// - public bool IsOptional { get; init; } + public global::System.Type Type => typeof(TRequest); /// /// Gets the primary content type accepted by the endpoint. @@ -428,6 +425,11 @@ internal sealed class {{AcceptsAttributeName}} : global::System.Attrib /// public string[] AdditionalContentTypes { get; } + /// + /// Gets a value indicating whether the request body is optional. + /// + public bool IsOptional { get; init; } + /// /// Initializes a new instance of the generic Accepts attribute class. /// @@ -435,8 +437,8 @@ internal sealed class {{AcceptsAttributeName}} : global::System.Attrib /// Additional content types accepted by the endpoint. public {{AcceptsAttributeName}}(string contentType = "application/json", params string[] additionalContentTypes) { - ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; - AdditionalContentTypes = additionalContentTypes ?? []; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes; } } @@ -457,15 +459,15 @@ internal sealed class {{EndpointFilterAttributeName}} : global::System.Attribute /// /// Gets the CLR type of the endpoint filter. /// - public global::System.Type FilterType { get; } + public global::System.Type Type { get; } /// /// Initializes a new instance of the class. /// - /// The CLR type of the endpoint filter. - public {{EndpointFilterAttributeName}}(global::System.Type filterType) + /// The CLR type of the endpoint filter. + public {{EndpointFilterAttributeName}}(global::System.Type type) { - FilterType = filterType ?? throw new global::System.ArgumentNullException(nameof(filterType)); + Type = type; } } @@ -479,7 +481,7 @@ internal sealed class {{EndpointFilterAttributeName}} : global::System. /// /// Gets the CLR type of the endpoint filter. /// - public global::System.Type FilterType => typeof(TFilter); + public global::System.Type Type => typeof(TFilter); } """, Encoding.UTF8 @@ -496,10 +498,10 @@ namespace {{AttributesNamespace}}; [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute { - /// - /// Gets the response type produced by the endpoint. - /// - public global::System.Type? ResponseType { get; init; } + /// + /// Gets the response type produced by the endpoint. + /// + public global::System.Type Type { get; } /// /// Gets the HTTP status code returned by the endpoint. @@ -519,11 +521,13 @@ internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribu /// /// Initializes a new instance of the class. /// + /// The CLR type of the response body. /// The HTTP status code returned by the endpoint. /// The primary content type produced by the endpoint. /// Additional content types produced by the endpoint. - public {{ProducesResponseAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) + public {{ProducesResponseAttributeName}}(global::System.Type type, int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) { + Type = type; StatusCode = statusCode; ContentType = contentType; AdditionalContentTypes = additionalContentTypes ?? []; @@ -540,7 +544,7 @@ internal sealed class {{ProducesResponseAttributeName}} : global::Sys /// /// Gets the response type produced by the endpoint. /// - public global::System.Type ResponseType => typeof(TResponse); + public global::System.Type Type => typeof(TResponse); /// /// Gets the HTTP status code returned by the endpoint. diff --git a/src/GeneratedEndpoints/Common/Constants.cs b/src/GeneratedEndpoints/Common/Constants.cs index 60f2ee2..d7cb149 100644 --- a/src/GeneratedEndpoints/Common/Constants.cs +++ b/src/GeneratedEndpoints/Common/Constants.cs @@ -1,5 +1,3 @@ -using System.ComponentModel; - namespace GeneratedEndpoints.Common; internal static partial class Constants @@ -7,8 +5,6 @@ internal static partial class Constants internal const string FallbackHttpMethod = "__FALLBACK__"; internal const string NameAttributeNamedParameter = "Name"; - internal const string ResponseTypeAttributeNamedParameter = "ResponseType"; - internal const string RequestTypeAttributeNamedParameter = "RequestType"; internal const string IsOptionalAttributeNamedParameter = "IsOptional"; internal const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute"; diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index 9bbcd49..ea80621 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -1,4 +1,3 @@ -using System.Runtime.CompilerServices; using Microsoft.CodeAnalysis; using static GeneratedEndpoints.Common.Constants; @@ -6,8 +5,6 @@ namespace GeneratedEndpoints.Common; internal static class EndpointConfigurationFactory { - private static readonly ConditionalWeakTable GeneratedAttributeKindCache = new(); - public static EndpointConfiguration Create(ISymbol symbol) { var attributes = symbol.GetAttributes(); @@ -48,7 +45,7 @@ public static EndpointConfiguration Create(ISymbol symbol) if (attributeClass is null) continue; - var attributeKind = GetGeneratedAttributeKind(attributeClass); + var attributeKind = attributeClass.OriginalDefinition.GetRequestHandlerAttributeKind(); switch (attributeKind) { case RequestHandlerAttributeKind.ShortCircuit: @@ -203,16 +200,6 @@ public static EndpointConfiguration Create(ISymbol symbol) return StringBuilderPool.ToStringAndReturn(builder); } - private static RequestHandlerAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) - { - var definition = attributeClass.OriginalDefinition; - var cacheEntry = GeneratedAttributeKindCache.GetValue( - definition, static def => new GeneratedAttributeKindCacheEntry(def.GetRequestHandlerAttributeKind()) - ); - - return cacheEntry.Kind; - } - private static EquatableImmutableArray? ToEquatableOrNull(List? values) { return values is { Count: > 0 } ? values.ToEquatableImmutableArray() : null; @@ -224,9 +211,10 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) requestType = attributeClass.TypeArguments[0] .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - else if (attribute.GetNamedTypeSymbol(RequestTypeAttributeNamedParameter) is { } requestTypeSymbol) - requestType = requestTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); else + requestType = attribute.GetConstructorTypeSymbol()?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + if (requestType is null) return; var contentType = attribute.GetConstructorStringValue() ?? ApplicationJsonContentType; @@ -241,23 +229,27 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym private static void TryAddProducesMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List? produces) { - string? responseType; + ProducesMetadata? producesMetadata; if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) - responseType = attributeClass.TypeArguments[0] + { + var responseType = attributeClass.TypeArguments[0] .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - else if (attribute.GetNamedTypeSymbol(ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol) - responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var statusCode = attribute.GetConstructorIntValue(position: 0) ?? 200; + var contentType = attribute.GetConstructorStringValue(position: 1); + var additionalContentTypes = attribute.GetConstructorStringArray(2); + producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes); + } else - return; - - var statusCode = attribute.GetConstructorIntValue() ?? 200; - var contentType = attribute.GetConstructorStringValue(1); - var additionalContentTypes = attribute.GetConstructorStringArray(2); - - var producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes); + { + var responseType = attribute.GetConstructorTypeSymbol()?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) ?? ""; + var statusCode = attribute.GetConstructorIntValue(position: 1) ?? 200; + var contentType = attribute.GetConstructorStringValue(position: 2); + var additionalContentTypes = attribute.GetConstructorStringArray(3); + producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes); + } var producesList = produces ??= []; - producesList.Add(producesMetadata); + producesList.Add(producesMetadata.Value); } private static void TryAddEndpointFilter( @@ -296,9 +288,4 @@ private static void TryAddEndpointFilterType(ITypeSymbol? typeSymbol, ref List("application/json", "application/xml", IsOptional = true)] - [ProducesResponse(StatusCodes.Status200OK, "application/json", ResponseType = typeof(UserProfile))] + [ProducesResponse(typeof(UserProfile), StatusCodes.Status200OK, "application/json")] [ProducesResponse(StatusCodes.Status202Accepted, "application/json")] [ProducesProblem(StatusCodes.Status500InternalServerError, "application/problem+json")] [ProducesValidationProblem(StatusCodes.Status400BadRequest, "application/problem+json")] diff --git a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.AcceptsAttribute.verified.txt b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.AcceptsAttribute.verified.txt index d4d0ac4..ed86e13 100644 --- a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.AcceptsAttribute.verified.txt +++ b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.AcceptsAttribute.verified.txt @@ -19,14 +19,9 @@ namespace Microsoft.AspNetCore.Generated.Attributes; internal sealed class AcceptsAttribute : global::System.Attribute { /// - /// Gets the request type accepted by the endpoint. + /// Gets the CLR type of the endpoint filter. /// - public global::System.Type? RequestType { get; init; } - - /// - /// Gets a value indicating whether the request body is optional. - /// - public bool IsOptional { get; init; } + public global::System.Type Type { get; } /// /// Gets the primary content type accepted by the endpoint. @@ -38,15 +33,22 @@ internal sealed class AcceptsAttribute : global::System.Attribute /// public string[] AdditionalContentTypes { get; } + /// + /// Gets a value indicating whether the request body is optional. + /// + public bool IsOptional { get; init; } + /// /// Initializes a new instance of the class. /// + /// The CLR type of the request body. /// The primary content type accepted by the endpoint. /// Additional content types accepted by the endpoint. - public AcceptsAttribute(string contentType = "application/json", params string[] additionalContentTypes) + public AcceptsAttribute(global::System.Type type, string contentType = "application/json", params string[] additionalContentTypes) { - ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; - AdditionalContentTypes = additionalContentTypes ?? []; + Type = type; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes; } } @@ -58,14 +60,9 @@ internal sealed class AcceptsAttribute : global::System.Attribute internal sealed class AcceptsAttribute : global::System.Attribute { /// - /// Gets the request type accepted by the endpoint. + /// Gets the CLR type of the endpoint filter. /// - public global::System.Type RequestType => typeof(TRequest); - - /// - /// Gets a value indicating whether the request body is optional. - /// - public bool IsOptional { get; init; } + public global::System.Type Type => typeof(TRequest); /// /// Gets the primary content type accepted by the endpoint. @@ -77,6 +74,11 @@ internal sealed class AcceptsAttribute : global::System.Attribute /// public string[] AdditionalContentTypes { get; } + /// + /// Gets a value indicating whether the request body is optional. + /// + public bool IsOptional { get; init; } + /// /// Initializes a new instance of the generic Accepts attribute class. /// @@ -84,7 +86,7 @@ internal sealed class AcceptsAttribute : global::System.Attribute /// Additional content types accepted by the endpoint. public AcceptsAttribute(string contentType = "application/json", params string[] additionalContentTypes) { - ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; - AdditionalContentTypes = additionalContentTypes ?? []; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes; } } diff --git a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.EndpointFilterAttribute.verified.txt b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.EndpointFilterAttribute.verified.txt index e96674f..7e69fab 100644 --- a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.EndpointFilterAttribute.verified.txt +++ b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.EndpointFilterAttribute.verified.txt @@ -21,15 +21,15 @@ internal sealed class EndpointFilterAttribute : global::System.Attribute /// /// Gets the CLR type of the endpoint filter. /// - public global::System.Type FilterType { get; } + public global::System.Type Type { get; } /// /// Initializes a new instance of the class. /// - /// The CLR type of the endpoint filter. - public EndpointFilterAttribute(global::System.Type filterType) + /// The CLR type of the endpoint filter. + public EndpointFilterAttribute(global::System.Type type) { - FilterType = filterType ?? throw new global::System.ArgumentNullException(nameof(filterType)); + Type = type; } } @@ -43,5 +43,5 @@ internal sealed class EndpointFilterAttribute : global::System.Attribut /// /// Gets the CLR type of the endpoint filter. /// - public global::System.Type FilterType => typeof(TFilter); + public global::System.Type Type => typeof(TFilter); } diff --git a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.ProducesResponseAttribute.verified.txt b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.ProducesResponseAttribute.verified.txt index e5d6748..25e39fb 100644 --- a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.ProducesResponseAttribute.verified.txt +++ b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.ProducesResponseAttribute.verified.txt @@ -18,10 +18,10 @@ namespace Microsoft.AspNetCore.Generated.Attributes; [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] internal sealed class ProducesResponseAttribute : global::System.Attribute { -/// -/// Gets the response type produced by the endpoint. -/// -public global::System.Type? ResponseType { get; init; } + /// + /// Gets the response type produced by the endpoint. + /// + public global::System.Type Type { get; } /// /// Gets the HTTP status code returned by the endpoint. @@ -41,11 +41,13 @@ public global::System.Type? ResponseType { get; init; } /// /// Initializes a new instance of the class. /// + /// The CLR type of the response body. /// The HTTP status code returned by the endpoint. /// The primary content type produced by the endpoint. /// Additional content types produced by the endpoint. - public ProducesResponseAttribute(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) + public ProducesResponseAttribute(global::System.Type type, int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) { + Type = type; StatusCode = statusCode; ContentType = contentType; AdditionalContentTypes = additionalContentTypes ?? []; @@ -62,7 +64,7 @@ internal sealed class ProducesResponseAttribute : global::System.Attr /// /// Gets the response type produced by the endpoint. /// - public global::System.Type ResponseType => typeof(TResponse); + public global::System.Type Type => typeof(TResponse); /// /// Gets the HTTP status code returned by the endpoint. diff --git a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs index 1989896..76181a1 100644 --- a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs +++ b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs @@ -380,7 +380,7 @@ public static string BuildContractsAndBindingSource( { var secondProduces = string.IsNullOrWhiteSpace(producesContentType2) ? "" : $", \"{producesContentType2}\""; builder.AppendLine( - $" [ProducesResponse(200, \"{producesContentType1 ?? "application/json"}\"{secondProduces}, ResponseType = typeof(ResponseRecord))]" + $" [ProducesResponse(typeof(ResponseRecord), 200, \"{producesContentType1 ?? "application/json"}\"{secondProduces})]" ); } From 761acf6a3a396675467af9c8d003e622c38fb95d Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:52:20 -0500 Subject: [PATCH 5/6] Cleanup. --- .../AddEndpointHandlersGenerator.cs | 5 +- .../Common/EndpointConfigurationFactory.cs | 15 +- .../Common/MethodSymbolExtensions.cs | 26 +- .../Common/NamedTypeSymbolExtensions.cs | 244 ---------- .../Common/TypeSymbolExtensions.cs | 438 ++++++++++++++++++ .../UseEndpointHandlersGenerator.cs | 137 ++++-- .../Common/SourceFactory.cs | 4 +- ...4B85A7155_MapEndpointHandlers.verified.txt | 1 - ...0CA9A1CFC_MapEndpointHandlers.verified.txt | 3 +- .../GeneratedSourceTests.cs | 10 +- .../IndividualTests.cs | 6 +- 11 files changed, 569 insertions(+), 320 deletions(-) delete mode 100644 src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs diff --git a/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs index bb258ee..6a6551f 100644 --- a/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs +++ b/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs @@ -17,7 +17,10 @@ public static void GenerateSource(SourceProductionContext context, ImmutableSort { context.CancellationToken.ThrowIfCancellationRequested(); - var nonStaticClassNames = grouped.Keys.Where(x => !x.IsStatic).Select(x => x.Name).ToList(); + var nonStaticClassNames = grouped.Keys + .Where(x => !x.IsStatic) + .Select(x => x.Name) + .ToList(); var source = new StringBuilder(); source.AppendLine(FileHeader); diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index ea80621..f8dccd5 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -212,7 +212,8 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym requestType = attributeClass.TypeArguments[0] .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); else - requestType = attribute.GetConstructorTypeSymbol()?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + requestType = attribute.GetConstructorTypeSymbol() + ?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); if (requestType is null) return; @@ -234,16 +235,18 @@ private static void TryAddProducesMetadata(AttributeData attribute, INamedTypeSy { var responseType = attributeClass.TypeArguments[0] .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var statusCode = attribute.GetConstructorIntValue(position: 0) ?? 200; - var contentType = attribute.GetConstructorStringValue(position: 1); + var statusCode = attribute.GetConstructorIntValue(0) ?? 200; + var contentType = attribute.GetConstructorStringValue(1); var additionalContentTypes = attribute.GetConstructorStringArray(2); producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes); } else { - var responseType = attribute.GetConstructorTypeSymbol()?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) ?? ""; - var statusCode = attribute.GetConstructorIntValue(position: 1) ?? 200; - var contentType = attribute.GetConstructorStringValue(position: 2); + var responseType = attribute.GetConstructorTypeSymbol() + ?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + ?? ""; + var statusCode = attribute.GetConstructorIntValue(1) ?? 200; + var contentType = attribute.GetConstructorStringValue(2); var additionalContentTypes = attribute.GetConstructorStringArray(3); producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes); } diff --git a/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs b/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs index c5261e1..4d0acd9 100644 --- a/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs +++ b/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs @@ -1,6 +1,5 @@ using System.Collections.Immutable; using Microsoft.CodeAnalysis; -using static GeneratedEndpoints.Common.AttributeSymbolMatcher; using static GeneratedEndpoints.Common.Constants; namespace GeneratedEndpoints.Common; @@ -45,7 +44,7 @@ private static string GetBindingPrefix(IParameterSymbol parameter) if (attributeClass is null) continue; - var attributeSource = GetBindingSourceFromAttributeClass(attributeClass); + var attributeSource = attributeClass.GetBindingSource(); if (attributeSource == BindingSource.None) continue; @@ -69,29 +68,6 @@ private static string GetBindingPrefix(IParameterSymbol parameter) return bindingPrefix; } - internal static readonly string[] AspNetCoreHttpNamespaceParts = ["Microsoft", "AspNetCore", "Http"]; - internal static readonly string[] AspNetCoreMvcNamespaceParts = ["Microsoft", "AspNetCore", "Mvc"]; - internal static readonly string[] ExtensionsDependencyInjectionNamespaceParts = ["Microsoft", "Extensions", "DependencyInjection"]; - - private static BindingSource GetBindingSourceFromAttributeClass(INamedTypeSymbol attributeClass) - { - var definition = attributeClass.OriginalDefinition; - var namespaceSymbol = definition.ContainingNamespace; - - return definition.Name switch - { - "FromRouteAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromRoute, - "FromQueryAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromQuery, - "FromHeaderAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromHeader, - "FromBodyAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromBody, - "FromFormAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromForm, - "FromServicesAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromServices, - "FromKeyedServicesAttribute" when IsInNamespace(namespaceSymbol, ExtensionsDependencyInjectionNamespaceParts) => BindingSource.FromKeyedServices, - "AsParametersAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreHttpNamespaceParts) => BindingSource.AsParameters, - _ => BindingSource.None, - }; - } - private static string GetBindingSourceAttribute(BindingSource source, TypedConstant? typedKey, string? bindingName) { switch (source) diff --git a/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs b/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs deleted file mode 100644 index 45d6211..0000000 --- a/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs +++ /dev/null @@ -1,244 +0,0 @@ -using Microsoft.CodeAnalysis; - -namespace GeneratedEndpoints.Common; - -/// Provides extension methods for working with named type symbols. -internal static class NamedTypeSymbolExtensions -{ - public static RequestHandlerAttributeKind GetRequestHandlerAttributeKind(this ITypeSymbol definition) - { - return definition switch - { - { - MetadataName: "DisplayNameAttribute", - ContainingNamespace: - { - Name: "ComponentModel", - ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true }, - }, - } => RequestHandlerAttributeKind.DisplayName, - { - MetadataName: "DescriptionAttribute", - ContainingNamespace: - { - Name: "ComponentModel", - ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true }, - }, - } => RequestHandlerAttributeKind.Description, - { - MetadataName: "AllowAnonymousAttribute", - ContainingNamespace: - { - Name: "Authorization", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}, - }, - } => RequestHandlerAttributeKind.AllowAnonymous, - { - MetadataName: "TagsAttribute", - ContainingNamespace: - { - Name: "Http", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}, - }, - } => RequestHandlerAttributeKind.Tags, - { - MetadataName: "ExcludeFromDescriptionAttribute", - ContainingNamespace: - { - Name: "Routing", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}, - }, - } => RequestHandlerAttributeKind.ExcludeFromDescription, - { - MetadataName: "ExcludeFromDescriptionAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.ExcludeFromDescription, - { - MetadataName: "ShortCircuitAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.ShortCircuit, - { - MetadataName: "DisableValidationAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.DisableValidation, - { - MetadataName: "DisableRequestTimeoutAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.DisableRequestTimeout, - { - MetadataName: "RequestTimeoutAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.RequestTimeout, - { - MetadataName: "OrderAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.Order, - { - MetadataName: "MapGroupAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.MapGroup, - { - MetadataName: "SummaryAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.Summary, - { - MetadataName: "AcceptsAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.Accepts, - { - MetadataName: "AcceptsAttribute`1", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.Accepts, - { - MetadataName: "ProducesResponseAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.ProducesResponse, - { - MetadataName: "ProducesResponseAttribute`1", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.ProducesResponse, - { - MetadataName: "RequireAuthorizationAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.RequireAuthorization, - { - MetadataName: "RequireCorsAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.RequireCors, - { - MetadataName: "RequireHostAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.RequireHost, - { - MetadataName: "RequireRateLimitingAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.RequireRateLimiting, - { - MetadataName: "EndpointFilterAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.EndpointFilter, - { - MetadataName: "EndpointFilterAttribute`1", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.EndpointFilter, - { - MetadataName: "DisableAntiforgeryAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.DisableAntiforgery, - { - MetadataName: "ProducesProblemAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.ProducesProblem, - { - MetadataName: "ProducesValidationProblemAttribute", - ContainingNamespace: - { - Name: "Attributes", - ContainingNamespace: { Name: "Generated", - ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true }}}, - }, - } => RequestHandlerAttributeKind.ProducesValidationProblem, - _ => RequestHandlerAttributeKind.None, - }; - } -} diff --git a/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs b/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs index 7ad9d5e..c7dcbe4 100644 --- a/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs +++ b/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs @@ -4,6 +4,444 @@ namespace GeneratedEndpoints.Common; internal static class TypeSymbolExtensions { + public static RequestHandlerAttributeKind GetRequestHandlerAttributeKind(this ITypeSymbol symbol) + { + var definition = symbol.OriginalDefinition; + return definition switch + { + { + MetadataName: "DisplayNameAttribute", + ContainingNamespace: + { + Name: "ComponentModel", + ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true }, + }, + } => RequestHandlerAttributeKind.DisplayName, + { + MetadataName: "DescriptionAttribute", + ContainingNamespace: + { + Name: "ComponentModel", + ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true }, + }, + } => RequestHandlerAttributeKind.Description, + { + MetadataName: "AllowAnonymousAttribute", + ContainingNamespace: + { + Name: "Authorization", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + } => RequestHandlerAttributeKind.AllowAnonymous, + { + MetadataName: "TagsAttribute", + ContainingNamespace: + { + Name: "Http", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + } => RequestHandlerAttributeKind.Tags, + { + MetadataName: "ExcludeFromDescriptionAttribute", + ContainingNamespace: + { + Name: "Routing", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + } => RequestHandlerAttributeKind.ExcludeFromDescription, + { + MetadataName: "ExcludeFromDescriptionAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.ExcludeFromDescription, + { + MetadataName: "ShortCircuitAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.ShortCircuit, + { + MetadataName: "DisableValidationAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.DisableValidation, + { + MetadataName: "DisableRequestTimeoutAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.DisableRequestTimeout, + { + MetadataName: "RequestTimeoutAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.RequestTimeout, + { + MetadataName: "OrderAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.Order, + { + MetadataName: "MapGroupAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.MapGroup, + { + MetadataName: "SummaryAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.Summary, + { + MetadataName: "AcceptsAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.Accepts, + { + MetadataName: "AcceptsAttribute`1", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.Accepts, + { + MetadataName: "ProducesResponseAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.ProducesResponse, + { + MetadataName: "ProducesResponseAttribute`1", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.ProducesResponse, + { + MetadataName: "RequireAuthorizationAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.RequireAuthorization, + { + MetadataName: "RequireCorsAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.RequireCors, + { + MetadataName: "RequireHostAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.RequireHost, + { + MetadataName: "RequireRateLimitingAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.RequireRateLimiting, + { + MetadataName: "EndpointFilterAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.EndpointFilter, + { + MetadataName: "EndpointFilterAttribute`1", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.EndpointFilter, + { + MetadataName: "DisableAntiforgeryAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.DisableAntiforgery, + { + MetadataName: "ProducesProblemAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.ProducesProblem, + { + MetadataName: "ProducesValidationProblemAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.ProducesValidationProblem, + _ => RequestHandlerAttributeKind.None, + }; + } + + public static BindingSource GetBindingSource(this ITypeSymbol symbol) + { + var definition = symbol.OriginalDefinition; + return definition switch + { + { + MetadataName: "FromRouteAttribute", + ContainingNamespace: + { + Name: "Mvc", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromRoute, + { + MetadataName: "FromQueryAttribute", + ContainingNamespace: + { + Name: "Mvc", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromQuery, + { + MetadataName: "FromHeaderAttribute", + ContainingNamespace: + { + Name: "Mvc", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromHeader, + { + MetadataName: "FromBodyAttribute", + ContainingNamespace: + { + Name: "Mvc", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromBody, + { + MetadataName: "FromFormAttribute", + ContainingNamespace: + { + Name: "Mvc", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromForm, + { + MetadataName: "FromServicesAttribute", + ContainingNamespace: + { + Name: "Mvc", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromServices, + { + MetadataName: "FromKeyedServicesAttribute", + ContainingNamespace: + { + Name: "DependencyInjection", + ContainingNamespace: + { + Name: "Extensions", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromKeyedServices, + { + MetadataName: "AsParametersAttribute", + ContainingNamespace: + { + Name: "Http", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.AsParameters, + _ => BindingSource.None, + }; + } + public static bool IsIEndpointConventionBuilder(this ITypeSymbol symbol) { return symbol is diff --git a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs index 4c9f37b..8711e7f 100644 --- a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs +++ b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs @@ -17,7 +17,9 @@ public static void GenerateSource(SourceProductionContext context, ImmutableSort { context.CancellationToken.ThrowIfCancellationRequested(); - var requestHandlers = grouped.Values.SelectMany(x => x).ToImmutableList(); + var requestHandlers = grouped.Values + .SelectMany(x => x) + .ToImmutableList(); var source = new StringBuilder(); source.AppendLine(FileHeader); @@ -477,52 +479,33 @@ private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfigu var displayName = methodConfiguration.DisplayName ?? classConfiguration.DisplayName; var summary = methodConfiguration.Summary ?? classConfiguration.Summary; var description = methodConfiguration.Description ?? classConfiguration.Description; - var tags = MergeDistinctStrings(methodConfiguration.Tags, classConfiguration.Tags); + var excludeFromDescription = methodConfiguration.ExcludeFromDescription || classConfiguration.ExcludeFromDescription; + var accepts = ConcatEquatable(methodConfiguration.Accepts, classConfiguration.Accepts); var produces = ConcatEquatable(methodConfiguration.Produces, classConfiguration.Produces); var producesProblem = ConcatEquatable(methodConfiguration.ProducesProblem, classConfiguration.ProducesProblem); var producesValidationProblem = ConcatEquatable(methodConfiguration.ProducesValidationProblem, classConfiguration.ProducesValidationProblem); - var excludeFromDescription = methodConfiguration.ExcludeFromDescription || classConfiguration.ExcludeFromDescription; - - var authorizationPolicies = MergeDistinctStrings(methodConfiguration.AuthorizationPolicies, classConfiguration.AuthorizationPolicies); + var shortCircuit = methodConfiguration.ShortCircuit || classConfiguration.ShortCircuit; + var order = methodConfiguration.Order ?? classConfiguration.Order; + var disableAntiforgery = methodConfiguration.DisableAntiforgery || classConfiguration.DisableAntiforgery; + var disableValidation = methodConfiguration.DisableValidation || classConfiguration.DisableValidation; var requiredHosts = MergeDistinctStrings(methodConfiguration.RequiredHosts, classConfiguration.RequiredHosts); var endpointFilterTypes = ConcatEquatable(methodConfiguration.EndpointFilterTypes, classConfiguration.EndpointFilterTypes); - var requireAuthorization = methodConfiguration.RequireAuthorization || classConfiguration.RequireAuthorization; - var disableAntiforgery = methodConfiguration.DisableAntiforgery || classConfiguration.DisableAntiforgery; - var allowAnonymous = methodConfiguration.AllowAnonymous || classConfiguration.AllowAnonymous; - - var requireCors = methodConfiguration.RequireCors || classConfiguration.RequireCors; - var corsPolicyName = methodConfiguration.CorsPolicyName ?? classConfiguration.CorsPolicyName; + var (allowAnonymous, requireAuthorization) = ResolveAuthorization(methodConfiguration, classConfiguration); + var authorizationPolicies = MergeDistinctStrings(methodConfiguration.AuthorizationPolicies, classConfiguration.AuthorizationPolicies); - var requireRateLimiting = methodConfiguration.RequireRateLimiting || classConfiguration.RequireRateLimiting; - var rateLimitingPolicyName = methodConfiguration.RateLimitingPolicyName ?? classConfiguration.RateLimitingPolicyName; + var (requireCors, corsPolicyName) = ResolveCors(methodConfiguration, classConfiguration); + var (requireRateLimiting, rateLimitingPolicyName) = ResolveRateLimiting(methodConfiguration, classConfiguration); - var shortCircuit = methodConfiguration.ShortCircuit || classConfiguration.ShortCircuit; - var disableValidation = methodConfiguration.DisableValidation || classConfiguration.DisableValidation; - var disableRequestTimeout = methodConfiguration.DisableRequestTimeout || classConfiguration.DisableRequestTimeout; - var withRequestTimeout = methodConfiguration.WithRequestTimeout || classConfiguration.WithRequestTimeout; + var (disableRequestTimeout, withRequestTimeout, requestTimeoutPolicyName) = ResolveRequestTimeout(methodConfiguration, classConfiguration); var groupIdentifier = methodConfiguration.Group?.Identifier ?? classConfiguration.Group?.Identifier; var groupPattern = methodConfiguration.Group?.Pattern ?? classConfiguration.Group?.Pattern; var groupName = methodConfiguration.Group?.Name ?? classConfiguration.Group?.Name; - string? requestTimeoutPolicyName = null; - if (methodConfiguration.WithRequestTimeout) - requestTimeoutPolicyName = methodConfiguration.RequestTimeoutPolicyName; - else if (classConfiguration.WithRequestTimeout) - requestTimeoutPolicyName = classConfiguration.RequestTimeoutPolicyName; - - if (disableRequestTimeout) - { - withRequestTimeout = false; - requestTimeoutPolicyName = null; - } - - var order = methodConfiguration.Order ?? classConfiguration.Order; - return new EndpointConfiguration { DisplayName = displayName, @@ -561,6 +544,98 @@ private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfigu }; } + private static (bool AllowAnonymous, bool RequireAuthorization) ResolveAuthorization( + EndpointConfiguration methodConfiguration, + EndpointConfiguration classConfiguration + ) + { + var methodReq = methodConfiguration.RequireAuthorization; + var methodAnon = !methodReq && methodConfiguration.AllowAnonymous; + + var classReq = classConfiguration.RequireAuthorization; + var classAnon = !classReq && classConfiguration.AllowAnonymous; + + var methodDeclares = methodConfiguration.AllowAnonymous || methodConfiguration.RequireAuthorization; + + if (methodDeclares) + { + // Method directive wins + if (methodReq) + return (AllowAnonymous: false, RequireAuthorization: true); + + if (methodAnon) + return (AllowAnonymous: true, RequireAuthorization: false); + + return (false, false); + } + + if (classReq) + return (AllowAnonymous: false, RequireAuthorization: true); + + if (classAnon) + return (AllowAnonymous: true, RequireAuthorization: false); + + return (AllowAnonymous: false, RequireAuthorization: false); + } + + private static (bool DisableRequestTimeout, bool WithRequestTimeout, string? RequestTimeoutPolicyName) ResolveRequestTimeout( + EndpointConfiguration methodConfiguration, + EndpointConfiguration classConfiguration + ) + { + var methodWith = methodConfiguration.WithRequestTimeout; + var methodDisable = !methodWith && methodConfiguration.DisableRequestTimeout; + + var classWith = classConfiguration.WithRequestTimeout; + var classDisable = !classWith && classConfiguration.DisableRequestTimeout; + + var methodDeclares = methodConfiguration.DisableRequestTimeout || methodConfiguration.WithRequestTimeout; + + if (methodDeclares) + { + if (methodWith) + return (DisableRequestTimeout: false, WithRequestTimeout: true, methodConfiguration.RequestTimeoutPolicyName); + + if (methodDisable) + return (DisableRequestTimeout: true, WithRequestTimeout: false, null); + + return (false, false, null); + } + + if (classWith) + return (DisableRequestTimeout: false, WithRequestTimeout: true, classConfiguration.RequestTimeoutPolicyName); + + if (classDisable) + return (DisableRequestTimeout: true, WithRequestTimeout: false, null); + + return (DisableRequestTimeout: false, WithRequestTimeout: false, null); + } + + private static (bool RequireCors, string? CorsPolicyName) ResolveCors(EndpointConfiguration methodConfiguration, EndpointConfiguration classConfiguration) + { + if (methodConfiguration.RequireCors) + return (RequireCors: true, methodConfiguration.CorsPolicyName); + + if (classConfiguration.RequireCors) + return (RequireCors: true, classConfiguration.CorsPolicyName); + + return (RequireCors: false, CorsPolicyName: null); + } + + private static (bool RequireRateLimiting, string? RateLimitingPolicyName) ResolveRateLimiting( + EndpointConfiguration methodConfiguration, + EndpointConfiguration classConfiguration + ) + { + if (methodConfiguration.RequireRateLimiting) + return (RequireRateLimiting: true, methodConfiguration.RateLimitingPolicyName); + + if (classConfiguration.RequireRateLimiting) + return (RequireRateLimiting: true, classConfiguration.RateLimitingPolicyName); + + return (RequireRateLimiting: false, RateLimitingPolicyName: null); + } + private static EquatableImmutableArray? MergeDistinctStrings(EquatableImmutableArray? first, EquatableImmutableArray? second) { if (first is not { Count: > 0 }) diff --git a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs index 76181a1..035ce8d 100644 --- a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs +++ b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs @@ -379,9 +379,7 @@ public static string BuildContractsAndBindingSource( if (includeProducesResponse) { var secondProduces = string.IsNullOrWhiteSpace(producesContentType2) ? "" : $", \"{producesContentType2}\""; - builder.AppendLine( - $" [ProducesResponse(typeof(ResponseRecord), 200, \"{producesContentType1 ?? "application/json"}\"{secondProduces})]" - ); + builder.AppendLine($" [ProducesResponse(typeof(ResponseRecord), 200, \"{producesContentType1 ?? "application/json"}\"{secondProduces})]"); } if (includeGenericProducesResponse) diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.AuthorizationAndMetadataMatrix_6F94B85A7155_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.AuthorizationAndMetadataMatrix_6F94B85A7155_MapEndpointHandlers.verified.txt index ed92292..ed5e8f5 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.AuthorizationAndMetadataMatrix_6F94B85A7155_MapEndpointHandlers.verified.txt +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.AuthorizationAndMetadataMatrix_6F94B85A7155_MapEndpointHandlers.verified.txt @@ -29,7 +29,6 @@ internal static class EndpointRouteBuilderExtensions .RequireAuthorization("MethodPolicy") .RequireCors("MethodCors") .RequireHost("services.contoso.com", "contoso.com") - .AllowAnonymous() .DisableRequestTimeout(); return builder; diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_8330CA9A1CFC_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_8330CA9A1CFC_MapEndpointHandlers.verified.txt index 493d722..311982e 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_8330CA9A1CFC_MapEndpointHandlers.verified.txt +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_8330CA9A1CFC_MapEndpointHandlers.verified.txt @@ -29,8 +29,7 @@ internal static class EndpointRouteBuilderExtensions .WithDescription("Shows binding and contract combinations.") .Accepts("application/json") .ProducesValidationProblem(422, "application/problem+json") - .RequireAuthorization("ContractsPolicy") - .AllowAnonymous(); + .RequireAuthorization("ContractsPolicy"); return builder; } diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs index ebf6aa6..725fa30 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs @@ -181,7 +181,9 @@ await result.VerifyAsync("MapEndpointHandlers.g.cs") [InlineData(false, true, false, true, false, true, false, true, false, true, false, false, true, false, true, false, true, "application/xml", null, "application/json", null )] - [InlineData(true, false, true, false, true, false, true, false, true, false, true, false, true, true, false, true, false, null, "text/plain", null, "text/plain")] + [InlineData(true, false, true, false, true, false, true, false, true, false, true, false, true, true, false, true, false, null, "text/plain", null, + "text/plain" + )] public async Task ContractsAndBindingMatrix( bool withNamespace, bool includeBindingNames, @@ -207,9 +209,9 @@ public async Task ContractsAndBindingMatrix( ) { var source = SourceFactory.BuildContractsAndBindingSource(includeBindingNames, includeAsParameters, includeFromServices, includeFromKeyedServices, - includeAccepts, includeGenericAccepts, includeProducesResponse, includeGenericProducesResponse, includeProducesProblem, includeProducesValidationProblem, - includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, methodRequiresAuthorization, - acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 + includeAccepts, includeGenericAccepts, includeProducesResponse, includeGenericProducesResponse, includeProducesProblem, + includeProducesValidationProblem, includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, + methodRequiresAuthorization, acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 ); var sources = TestHelpers.GetSources(source, withNamespace); diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.cs b/tests/GeneratedEndpoints.Tests/IndividualTests.cs index 55adbde..bd0d6c3 100644 --- a/tests/GeneratedEndpoints.Tests/IndividualTests.cs +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.cs @@ -569,9 +569,9 @@ private static string ContractScenario( ) { return SourceFactory.BuildContractsAndBindingSource(includeBindingNames, includeAsParameters, includeFromServices, includeFromKeyedServices, - includeAccepts, includeGenericAccepts, includeProducesResponse, includeGenericProducesResponse, includeProducesProblem, includeProducesValidationProblem, - includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, methodRequiresAuthorization, - acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 + includeAccepts, includeGenericAccepts, includeProducesResponse, includeGenericProducesResponse, includeProducesProblem, + includeProducesValidationProblem, includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, + methodRequiresAuthorization, acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 ); } From 81c3ec9ccef391405bae6debd6ec17e0738ef4f7 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:52:46 -0500 Subject: [PATCH 6/6] Bump version. --- src/GeneratedEndpoints/GeneratedEndpoints.csproj | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/GeneratedEndpoints/GeneratedEndpoints.csproj b/src/GeneratedEndpoints/GeneratedEndpoints.csproj index ca27350..efddba0 100644 --- a/src/GeneratedEndpoints/GeneratedEndpoints.csproj +++ b/src/GeneratedEndpoints/GeneratedEndpoints.csproj @@ -36,9 +36,9 @@ GeneratedEndpoints - 10.0.1 - 10.0.1.0 - 10.0.1.0 + 10.0.2 + 10.0.2.0 + 10.0.2.0 en-US false