From 3b122eef5238c390cd1044b8ec428ab2c946eabc Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 15:17:57 -0500 Subject: [PATCH 01/32] Move nested generator types to Common (#61) --- .../Common/AcceptsMetadata.cs | 8 + .../Common/AttributeSymbolMatcher.cs | 25 + .../Common/BindingSource.cs | 14 + .../Common/CompilationTypeCache.cs | 11 + .../Common/ConfigureMethodDetails.cs | 6 + .../Constants.GeneratedSources.cs} | 198 ++--- src/GeneratedEndpoints/Common/Constants.cs | 107 +++ .../Common/EndpointAttributeState.cs | 34 + .../Common/EndpointConfiguration.cs | 22 + .../Common/EndpointConfigurationFactory.cs | 531 +++++++++++++ .../Common/GeneratedAttributeKind.cs | 23 + .../Common/HttpAttributeDefinition.cs | 11 + src/GeneratedEndpoints/Common/Parameter.cs | 3 + .../Common/ProducesMetadata.cs | 8 + .../Common/ProducesProblemMetadata.cs | 7 + .../ProducesValidationProblemMetadata.cs | 7 + .../Common/RequestHandler.cs | 9 + .../Common/RequestHandlerClass.cs | 11 + .../Common/RequestHandlerClassCacheEntry.cs | 55 ++ .../Common/RequestHandlerComparer.cs | 26 + .../Common/RequestHandlerMetadata.cs | 14 + .../Common/RequestHandlerMethod.cs | 8 + .../Common/RequestHandlerParameterHelper.cs | 131 ++++ .../MinimalApiGenerator.Types.cs | 234 ------ src/GeneratedEndpoints/MinimalApiGenerator.cs | 701 +----------------- 25 files changed, 1155 insertions(+), 1049 deletions(-) create mode 100644 src/GeneratedEndpoints/Common/AcceptsMetadata.cs create mode 100644 src/GeneratedEndpoints/Common/AttributeSymbolMatcher.cs create mode 100644 src/GeneratedEndpoints/Common/BindingSource.cs create mode 100644 src/GeneratedEndpoints/Common/CompilationTypeCache.cs create mode 100644 src/GeneratedEndpoints/Common/ConfigureMethodDetails.cs rename src/GeneratedEndpoints/{MinimalApiGenerator.Constants.cs => Common/Constants.GeneratedSources.cs} (87%) create mode 100644 src/GeneratedEndpoints/Common/Constants.cs create mode 100644 src/GeneratedEndpoints/Common/EndpointAttributeState.cs create mode 100644 src/GeneratedEndpoints/Common/EndpointConfiguration.cs create mode 100644 src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs create mode 100644 src/GeneratedEndpoints/Common/GeneratedAttributeKind.cs create mode 100644 src/GeneratedEndpoints/Common/HttpAttributeDefinition.cs create mode 100644 src/GeneratedEndpoints/Common/Parameter.cs create mode 100644 src/GeneratedEndpoints/Common/ProducesMetadata.cs create mode 100644 src/GeneratedEndpoints/Common/ProducesProblemMetadata.cs create mode 100644 src/GeneratedEndpoints/Common/ProducesValidationProblemMetadata.cs create mode 100644 src/GeneratedEndpoints/Common/RequestHandler.cs create mode 100644 src/GeneratedEndpoints/Common/RequestHandlerClass.cs create mode 100644 src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs create mode 100644 src/GeneratedEndpoints/Common/RequestHandlerComparer.cs create mode 100644 src/GeneratedEndpoints/Common/RequestHandlerMetadata.cs create mode 100644 src/GeneratedEndpoints/Common/RequestHandlerMethod.cs create mode 100644 src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs delete mode 100644 src/GeneratedEndpoints/MinimalApiGenerator.Types.cs diff --git a/src/GeneratedEndpoints/Common/AcceptsMetadata.cs b/src/GeneratedEndpoints/Common/AcceptsMetadata.cs new file mode 100644 index 0000000..3656580 --- /dev/null +++ b/src/GeneratedEndpoints/Common/AcceptsMetadata.cs @@ -0,0 +1,8 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct AcceptsMetadata( + string RequestType, + string ContentType, + EquatableImmutableArray? AdditionalContentTypes, + bool IsOptional +); diff --git a/src/GeneratedEndpoints/Common/AttributeSymbolMatcher.cs b/src/GeneratedEndpoints/Common/AttributeSymbolMatcher.cs new file mode 100644 index 0000000..fc5dc15 --- /dev/null +++ b/src/GeneratedEndpoints/Common/AttributeSymbolMatcher.cs @@ -0,0 +1,25 @@ +using Microsoft.CodeAnalysis; + +namespace GeneratedEndpoints.Common; + +internal static class AttributeSymbolMatcher +{ + public static bool IsAttribute(INamedTypeSymbol attributeClass, string attributeName, string[] namespaceParts) + { + var definition = attributeClass.OriginalDefinition; + return definition.Name == attributeName && IsInNamespace(definition.ContainingNamespace, namespaceParts); + } + + public static bool IsInNamespace(INamespaceSymbol? namespaceSymbol, string[] namespaceParts) + { + for (var i = namespaceParts.Length - 1; i >= 0; i--) + { + if (namespaceSymbol is null || namespaceSymbol.Name != namespaceParts[i]) + return false; + + namespaceSymbol = namespaceSymbol.ContainingNamespace; + } + + return namespaceSymbol is null || namespaceSymbol.IsGlobalNamespace; + } +} diff --git a/src/GeneratedEndpoints/Common/BindingSource.cs b/src/GeneratedEndpoints/Common/BindingSource.cs new file mode 100644 index 0000000..4db1144 --- /dev/null +++ b/src/GeneratedEndpoints/Common/BindingSource.cs @@ -0,0 +1,14 @@ +namespace GeneratedEndpoints.Common; + +internal enum BindingSource +{ + None = 0, + FromRoute = 1, + FromQuery = 2, + FromHeader = 3, + FromBody = 4, + FromForm = 5, + FromServices = 6, + FromKeyedServices = 7, + AsParameters = 8, +} diff --git a/src/GeneratedEndpoints/Common/CompilationTypeCache.cs b/src/GeneratedEndpoints/Common/CompilationTypeCache.cs new file mode 100644 index 0000000..0567f3a --- /dev/null +++ b/src/GeneratedEndpoints/Common/CompilationTypeCache.cs @@ -0,0 +1,11 @@ +using Microsoft.CodeAnalysis; + +namespace GeneratedEndpoints.Common; + +internal sealed class CompilationTypeCache(Compilation compilation) +{ + public INamedTypeSymbol? EndpointConventionBuilderSymbol { get; } = + compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Builder.IEndpointConventionBuilder"); + + public INamedTypeSymbol? ServiceProviderSymbol { get; } = compilation.GetTypeByMetadataName("System.IServiceProvider"); +} diff --git a/src/GeneratedEndpoints/Common/ConfigureMethodDetails.cs b/src/GeneratedEndpoints/Common/ConfigureMethodDetails.cs new file mode 100644 index 0000000..0e3b859 --- /dev/null +++ b/src/GeneratedEndpoints/Common/ConfigureMethodDetails.cs @@ -0,0 +1,6 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct ConfigureMethodDetails( + bool HasConfigureMethod, + bool ConfigureMethodAcceptsServiceProvider +); diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.Constants.cs b/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs similarity index 87% rename from src/GeneratedEndpoints/MinimalApiGenerator.Constants.cs rename to src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs index 0f51419..b2ef684 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.Constants.cs +++ b/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs @@ -1,122 +1,13 @@ using System.Collections.Immutable; -using System.Runtime.CompilerServices; using System.Text; -using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.Text; +using GeneratedEndpoints; -namespace GeneratedEndpoints; +namespace GeneratedEndpoints.Common; -public sealed partial class MinimalApiGenerator +internal static partial class Constants { - private const string BaseNamespace = "Microsoft.AspNetCore.Generated"; - private const string AttributesNamespace = $"{BaseNamespace}.Attributes"; - - private const string FallbackHttpMethod = "__FALLBACK__"; - - private const string NameAttributeNamedParameter = "Name"; - private const string ResponseTypeAttributeNamedParameter = "ResponseType"; - private const string RequestTypeAttributeNamedParameter = "RequestType"; - private const string IsOptionalAttributeNamedParameter = "IsOptional"; - private const string PolicyNameAttributeNamedParameter = "PolicyName"; - - private const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute"; - private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}"; - private const string RequireAuthorizationAttributeHint = $"{RequireAuthorizationAttributeFullyQualifiedName}.gs.cs"; - - private const string RequireCorsAttributeName = "RequireCorsAttribute"; - private const string RequireCorsAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireCorsAttributeName}"; - private const string RequireCorsAttributeHint = $"{RequireCorsAttributeFullyQualifiedName}.gs.cs"; - - private const string RequireRateLimitingAttributeName = "RequireRateLimitingAttribute"; - private const string RequireRateLimitingAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireRateLimitingAttributeName}"; - private const string RequireRateLimitingAttributeHint = $"{RequireRateLimitingAttributeFullyQualifiedName}.gs.cs"; - - private const string RequireHostAttributeName = "RequireHostAttribute"; - private const string RequireHostAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireHostAttributeName}"; - private const string RequireHostAttributeHint = $"{RequireHostAttributeFullyQualifiedName}.gs.cs"; - - private const string DisableAntiforgeryAttributeName = "DisableAntiforgeryAttribute"; - private const string DisableAntiforgeryAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableAntiforgeryAttributeName}"; - private const string DisableAntiforgeryAttributeHint = $"{DisableAntiforgeryAttributeFullyQualifiedName}.gs.cs"; - - private const string ShortCircuitAttributeName = "ShortCircuitAttribute"; - private const string ShortCircuitAttributeFullyQualifiedName = $"{AttributesNamespace}.{ShortCircuitAttributeName}"; - private const string ShortCircuitAttributeHint = $"{ShortCircuitAttributeFullyQualifiedName}.gs.cs"; - - private const string DisableRequestTimeoutAttributeName = "DisableRequestTimeoutAttribute"; - private const string DisableRequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableRequestTimeoutAttributeName}"; - private const string DisableRequestTimeoutAttributeHint = $"{DisableRequestTimeoutAttributeFullyQualifiedName}.gs.cs"; - - private const string DisableValidationAttributeName = "DisableValidationAttribute"; - private const string DisableValidationAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableValidationAttributeName}"; - private const string DisableValidationAttributeHint = $"{DisableValidationAttributeFullyQualifiedName}.gs.cs"; - - private const string RequestTimeoutAttributeName = "RequestTimeoutAttribute"; - private const string RequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequestTimeoutAttributeName}"; - private const string RequestTimeoutAttributeHint = $"{RequestTimeoutAttributeFullyQualifiedName}.gs.cs"; - - private const string OrderAttributeName = "OrderAttribute"; - private const string OrderAttributeFullyQualifiedName = $"{AttributesNamespace}.{OrderAttributeName}"; - private const string OrderAttributeHint = $"{OrderAttributeFullyQualifiedName}.gs.cs"; - - private const string MapGroupAttributeName = "MapGroupAttribute"; - private const string MapGroupAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapGroupAttributeName}"; - private const string MapGroupAttributeHint = $"{MapGroupAttributeFullyQualifiedName}.gs.cs"; - - private const string SummaryAttributeName = "SummaryAttribute"; - private const string SummaryAttributeFullyQualifiedName = $"{AttributesNamespace}.{SummaryAttributeName}"; - private const string SummaryAttributeHint = $"{SummaryAttributeFullyQualifiedName}.gs.cs"; - - private const string AllowAnonymousAttributeName = "AllowAnonymousAttribute"; - - private const string EndpointFilterAttributeName = "EndpointFilterAttribute"; - private const string EndpointFilterAttributeFullyQualifiedName = $"{AttributesNamespace}.{EndpointFilterAttributeName}"; - private const string EndpointFilterAttributeHint = $"{EndpointFilterAttributeFullyQualifiedName}.gs.cs"; - - private const string AcceptsAttributeName = "AcceptsAttribute"; - private const string AcceptsAttributeFullyQualifiedName = $"{AttributesNamespace}.{AcceptsAttributeName}"; - private const string AcceptsAttributeHint = $"{AcceptsAttributeFullyQualifiedName}.gs.cs"; - - private const string ProducesResponseAttributeName = "ProducesResponseAttribute"; - private const string ProducesResponseAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesResponseAttributeName}"; - private const string ProducesResponseAttributeHint = $"{ProducesResponseAttributeFullyQualifiedName}.gs.cs"; - - private const string ProducesProblemAttributeName = "ProducesProblemAttribute"; - private const string ProducesProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesProblemAttributeName}"; - private const string ProducesProblemAttributeHint = $"{ProducesProblemAttributeFullyQualifiedName}.gs.cs"; - - private const string ProducesValidationProblemAttributeName = "ProducesValidationProblemAttribute"; - - private const string ProducesValidationProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesValidationProblemAttributeName}"; - - private const string ProducesValidationProblemAttributeHint = $"{ProducesValidationProblemAttributeFullyQualifiedName}.gs.cs"; - - private const string RoutingNamespace = $"{BaseNamespace}.Routing"; - - private const string AddEndpointHandlersClassName = "EndpointServicesExtensions"; - private const string AddEndpointHandlersMethodName = "AddEndpointHandlers"; - private const string AddEndpointHandlersMethodHint = $"{RoutingNamespace}.{AddEndpointHandlersMethodName}.g.cs"; - - private const string UseEndpointHandlersClassName = "EndpointRouteBuilderExtensions"; - private const string UseEndpointHandlersMethodName = "MapEndpointHandlers"; - private const string UseEndpointHandlersMethodHint = $"{RoutingNamespace}.{UseEndpointHandlersMethodName}.g.cs"; - - private const string ConfigureMethodName = "Configure"; - private const string AsyncSuffix = "Async"; - private const string GlobalPrefix = "global::"; - private static readonly string[] AttributesNamespaceParts = AttributesNamespace.Split('.'); - private static readonly string[] AspNetCoreHttpNamespaceParts = ["Microsoft", "AspNetCore", "Http"]; - private static readonly string[] AspNetCoreMvcNamespaceParts = ["Microsoft", "AspNetCore", "Mvc"]; - private static readonly string[] AspNetCoreAuthorizationNamespaceParts = ["Microsoft", "AspNetCore", "Authorization"]; - private static readonly string[] AspNetCoreRoutingNamespaceParts = ["Microsoft", "AspNetCore", "Routing"]; - private static readonly string[] ExtensionsDependencyInjectionNamespaceParts = - ["Microsoft", "Extensions", "DependencyInjection"]; - private static readonly string[] ComponentModelNamespaceParts = ["System", "ComponentModel"]; - private static readonly ConditionalWeakTable CompilationTypeCaches = new(); - private static readonly ConditionalWeakTable RequestHandlerClassCache = new(); - private static readonly ConditionalWeakTable GeneratedAttributeKindCache = new(); - - private static readonly string FileHeader = $""" + internal static readonly string FileHeader = $""" //----------------------------------------------------------------------------- // // This code was generated by {nameof(MinimalApiGenerator)} which can be found @@ -130,7 +21,7 @@ public sealed partial class MinimalApiGenerator #nullable enable """; - private static readonly ImmutableArray HttpAttributeDefinitions = + internal static readonly ImmutableArray HttpAttributeDefinitions = [ CreateHttpAttributeDefinition("MapGetAttribute", "GET"), CreateHttpAttributeDefinition("MapPostAttribute", "POST"), @@ -145,10 +36,10 @@ public sealed partial class MinimalApiGenerator CreateHttpAttributeDefinition("MapFallbackAttribute", FallbackHttpMethod, true), ]; - private static readonly ImmutableDictionary HttpAttributeDefinitionsByName = + internal static readonly ImmutableDictionary HttpAttributeDefinitionsByName = HttpAttributeDefinitions.ToImmutableDictionary(static definition => definition.Name); - private static readonly SourceText RequireAuthorizationAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText RequireAuthorizationAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -184,7 +75,7 @@ internal sealed class {{RequireAuthorizationAttributeName}} : global::System.Att """, Encoding.UTF8 ); - private static readonly SourceText RequireCorsAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText RequireCorsAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -218,7 +109,7 @@ internal sealed class {{RequireCorsAttributeName}} : global::System.Attribute """, Encoding.UTF8 ); - private static readonly SourceText RequireRateLimitingAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText RequireRateLimitingAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -246,7 +137,7 @@ internal sealed class {{RequireRateLimitingAttributeName}} : global::System.Attr """, Encoding.UTF8 ); - private static readonly SourceText RequireHostAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText RequireHostAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -274,7 +165,7 @@ internal sealed class {{RequireHostAttributeName}} : global::System.Attribute """, Encoding.UTF8 ); - private static readonly SourceText DisableAntiforgeryAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText DisableAntiforgeryAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -290,7 +181,7 @@ internal sealed class {{DisableAntiforgeryAttributeName}} : global::System.Attri """, Encoding.UTF8 ); - private static readonly SourceText ShortCircuitAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText ShortCircuitAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -306,7 +197,7 @@ internal sealed class {{ShortCircuitAttributeName}} : global::System.Attribute """, Encoding.UTF8 ); - private static readonly SourceText DisableRequestTimeoutAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText DisableRequestTimeoutAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -322,7 +213,7 @@ internal sealed class {{DisableRequestTimeoutAttributeName}} : global::System.At """, Encoding.UTF8 ); - private static readonly SourceText DisableValidationAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText DisableValidationAttributeSourceText = SourceText.From($$""" #if NET10_0_OR_GREATER {{FileHeader}} @@ -340,7 +231,7 @@ internal sealed class {{DisableValidationAttributeName}} : global::System.Attrib """, Encoding.UTF8 ); - private static readonly SourceText RequestTimeoutAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText RequestTimeoutAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -376,7 +267,7 @@ internal sealed class {{RequestTimeoutAttributeName}} : global::System.Attribute """, Encoding.UTF8 ); - private static readonly SourceText OrderAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText OrderAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -405,7 +296,7 @@ internal sealed class {{OrderAttributeName}} : global::System.Attribute """, Encoding.UTF8 ); - private static readonly SourceText MapGroupAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText MapGroupAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -439,7 +330,7 @@ internal sealed class {{MapGroupAttributeName}} : global::System.Attribute """, Encoding.UTF8 ); - private static readonly SourceText SummaryAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText SummaryAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -468,7 +359,7 @@ internal sealed class {{SummaryAttributeName}} : global::System.Attribute """, Encoding.UTF8 ); - private static readonly SourceText AcceptsAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText AcceptsAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -553,7 +444,7 @@ internal sealed class {{AcceptsAttributeName}} : global::System.Attrib """, Encoding.UTF8 ); - private static readonly SourceText EndpointFilterAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText EndpointFilterAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -595,7 +486,7 @@ internal sealed class {{EndpointFilterAttributeName}} : global::System. """, Encoding.UTF8 ); - private static readonly SourceText ProducesResponseAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText ProducesResponseAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -684,7 +575,7 @@ internal sealed class {{ProducesResponseAttributeName}} : global::Sys """, Encoding.UTF8 ); - private static readonly SourceText ProducesProblemAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText ProducesProblemAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -727,7 +618,7 @@ internal sealed class {{ProducesProblemAttributeName}} : global::System.Attribut """, Encoding.UTF8 ); - private static readonly SourceText ProducesValidationProblemAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText ProducesValidationProblemAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; @@ -770,4 +661,47 @@ internal sealed class {{ProducesValidationProblemAttributeName}} : global::Syste """, Encoding.UTF8 ); + private static HttpAttributeDefinition CreateHttpAttributeDefinition(string attributeName, string verb, bool allowOptionalPattern = false) + { + var fullyQualifiedName = $"{AttributesNamespace}.{attributeName}"; + var hint = $"{fullyQualifiedName}.gs.cs"; + var summaryVerb = verb == FallbackHttpMethod ? "fallback" : verb; + var source = GenerateHttpAttributeSource(AttributesNamespace, attributeName, summaryVerb, allowOptionalPattern); + return new HttpAttributeDefinition(attributeName, fullyQualifiedName, hint, verb, SourceText.From(source, Encoding.UTF8)); + } + + private static string GenerateHttpAttributeSource(string attributesNamespace, string attributeName, string summaryVerb, bool allowOptionalPattern = false) + { + return $$""" + {{FileHeader}} + + namespace {{attributesNamespace}}; + + /// + /// Identifies a method as an HTTP {{summaryVerb}} minimal API endpoint with the specified route pattern. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{attributeName}} : global::System.Attribute + { + /// + /// Gets the route pattern for the endpoint. + /// + public string{{(allowOptionalPattern ? "?" : "")}} Pattern { get; } + + /// + /// Gets or sets the endpoint name. + /// + public string? Name { get; init; } + + /// + /// Initializes a new instance of the class. + /// + /// The route pattern for the endpoint. + public {{attributeName}}([global::System.Diagnostics.CodeAnalysis.StringSyntax("Route")] string{{(allowOptionalPattern ? "?" : "")}} pattern{{(allowOptionalPattern ? " = null" : "")}}) + { + Pattern = pattern; + } + } + """; + } } diff --git a/src/GeneratedEndpoints/Common/Constants.cs b/src/GeneratedEndpoints/Common/Constants.cs new file mode 100644 index 0000000..a75c1c2 --- /dev/null +++ b/src/GeneratedEndpoints/Common/Constants.cs @@ -0,0 +1,107 @@ +namespace GeneratedEndpoints.Common; + +internal static partial class Constants +{ + internal const string BaseNamespace = "Microsoft.AspNetCore.Generated"; + internal const string AttributesNamespace = $"{BaseNamespace}.Attributes"; + + internal const string FallbackHttpMethod = "__FALLBACK__"; + + internal const string NameAttributeNamedParameter = "Name"; + internal const string ResponseTypeAttributeNamedParameter = "ResponseType"; + internal const string RequestTypeAttributeNamedParameter = "RequestType"; + internal const string IsOptionalAttributeNamedParameter = "IsOptional"; + internal const string PolicyNameAttributeNamedParameter = "PolicyName"; + + internal const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute"; + internal const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}"; + internal const string RequireAuthorizationAttributeHint = $"{RequireAuthorizationAttributeFullyQualifiedName}.gs.cs"; + + internal const string RequireCorsAttributeName = "RequireCorsAttribute"; + internal const string RequireCorsAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireCorsAttributeName}"; + internal const string RequireCorsAttributeHint = $"{RequireCorsAttributeFullyQualifiedName}.gs.cs"; + + internal const string RequireRateLimitingAttributeName = "RequireRateLimitingAttribute"; + internal const string RequireRateLimitingAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireRateLimitingAttributeName}"; + internal const string RequireRateLimitingAttributeHint = $"{RequireRateLimitingAttributeFullyQualifiedName}.gs.cs"; + + internal const string RequireHostAttributeName = "RequireHostAttribute"; + internal const string RequireHostAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireHostAttributeName}"; + internal const string RequireHostAttributeHint = $"{RequireHostAttributeFullyQualifiedName}.gs.cs"; + + internal const string DisableAntiforgeryAttributeName = "DisableAntiforgeryAttribute"; + internal const string DisableAntiforgeryAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableAntiforgeryAttributeName}"; + internal const string DisableAntiforgeryAttributeHint = $"{DisableAntiforgeryAttributeFullyQualifiedName}.gs.cs"; + + internal const string ShortCircuitAttributeName = "ShortCircuitAttribute"; + internal const string ShortCircuitAttributeFullyQualifiedName = $"{AttributesNamespace}.{ShortCircuitAttributeName}"; + internal const string ShortCircuitAttributeHint = $"{ShortCircuitAttributeFullyQualifiedName}.gs.cs"; + + internal const string DisableRequestTimeoutAttributeName = "DisableRequestTimeoutAttribute"; + internal const string DisableRequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableRequestTimeoutAttributeName}"; + internal const string DisableRequestTimeoutAttributeHint = $"{DisableRequestTimeoutAttributeFullyQualifiedName}.gs.cs"; + + internal const string DisableValidationAttributeName = "DisableValidationAttribute"; + internal const string DisableValidationAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableValidationAttributeName}"; + internal const string DisableValidationAttributeHint = $"{DisableValidationAttributeFullyQualifiedName}.gs.cs"; + + internal const string RequestTimeoutAttributeName = "RequestTimeoutAttribute"; + internal const string RequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequestTimeoutAttributeName}"; + internal const string RequestTimeoutAttributeHint = $"{RequestTimeoutAttributeFullyQualifiedName}.gs.cs"; + + internal const string OrderAttributeName = "OrderAttribute"; + internal const string OrderAttributeFullyQualifiedName = $"{AttributesNamespace}.{OrderAttributeName}"; + internal const string OrderAttributeHint = $"{OrderAttributeFullyQualifiedName}.gs.cs"; + + internal const string MapGroupAttributeName = "MapGroupAttribute"; + internal const string MapGroupAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapGroupAttributeName}"; + internal const string MapGroupAttributeHint = $"{MapGroupAttributeFullyQualifiedName}.gs.cs"; + + internal const string SummaryAttributeName = "SummaryAttribute"; + internal const string SummaryAttributeFullyQualifiedName = $"{AttributesNamespace}.{SummaryAttributeName}"; + internal const string SummaryAttributeHint = $"{SummaryAttributeFullyQualifiedName}.gs.cs"; + + internal const string AllowAnonymousAttributeName = "AllowAnonymousAttribute"; + + internal const string EndpointFilterAttributeName = "EndpointFilterAttribute"; + internal const string EndpointFilterAttributeFullyQualifiedName = $"{AttributesNamespace}.{EndpointFilterAttributeName}"; + internal const string EndpointFilterAttributeHint = $"{EndpointFilterAttributeFullyQualifiedName}.gs.cs"; + + internal const string AcceptsAttributeName = "AcceptsAttribute"; + internal const string AcceptsAttributeFullyQualifiedName = $"{AttributesNamespace}.{AcceptsAttributeName}"; + internal const string AcceptsAttributeHint = $"{AcceptsAttributeFullyQualifiedName}.gs.cs"; + + internal const string ProducesResponseAttributeName = "ProducesResponseAttribute"; + internal const string ProducesResponseAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesResponseAttributeName}"; + internal const string ProducesResponseAttributeHint = $"{ProducesResponseAttributeFullyQualifiedName}.gs.cs"; + + internal const string ProducesProblemAttributeName = "ProducesProblemAttribute"; + internal const string ProducesProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesProblemAttributeName}"; + internal const string ProducesProblemAttributeHint = $"{ProducesProblemAttributeFullyQualifiedName}.gs.cs"; + + internal const string ProducesValidationProblemAttributeName = "ProducesValidationProblemAttribute"; + internal const string ProducesValidationProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesValidationProblemAttributeName}"; + internal const string ProducesValidationProblemAttributeHint = $"{ProducesValidationProblemAttributeFullyQualifiedName}.gs.cs"; + + internal const string RoutingNamespace = $"{BaseNamespace}.Routing"; + + internal const string AddEndpointHandlersClassName = "EndpointServicesExtensions"; + internal const string AddEndpointHandlersMethodName = "AddEndpointHandlers"; + internal const string AddEndpointHandlersMethodHint = $"{RoutingNamespace}.{AddEndpointHandlersMethodName}.g.cs"; + + internal const string UseEndpointHandlersClassName = "EndpointRouteBuilderExtensions"; + internal const string UseEndpointHandlersMethodName = "MapEndpointHandlers"; + internal const string UseEndpointHandlersMethodHint = $"{RoutingNamespace}.{UseEndpointHandlersMethodName}.g.cs"; + + internal const string ConfigureMethodName = "Configure"; + internal const string AsyncSuffix = "Async"; + internal const string GlobalPrefix = "global::"; + + internal static readonly string[] AttributesNamespaceParts = AttributesNamespace.Split('.'); + internal static readonly string[] AspNetCoreHttpNamespaceParts = ["Microsoft", "AspNetCore", "Http"]; + internal static readonly string[] AspNetCoreMvcNamespaceParts = ["Microsoft", "AspNetCore", "Mvc"]; + internal static readonly string[] AspNetCoreAuthorizationNamespaceParts = ["Microsoft", "AspNetCore", "Authorization"]; + internal static readonly string[] AspNetCoreRoutingNamespaceParts = ["Microsoft", "AspNetCore", "Routing"]; + internal static readonly string[] ExtensionsDependencyInjectionNamespaceParts = ["Microsoft", "Extensions", "DependencyInjection"]; + internal static readonly string[] ComponentModelNamespaceParts = ["System", "ComponentModel"]; +} diff --git a/src/GeneratedEndpoints/Common/EndpointAttributeState.cs b/src/GeneratedEndpoints/Common/EndpointAttributeState.cs new file mode 100644 index 0000000..e455138 --- /dev/null +++ b/src/GeneratedEndpoints/Common/EndpointAttributeState.cs @@ -0,0 +1,34 @@ +using System.Collections.Generic; + +namespace GeneratedEndpoints.Common; + +internal struct EndpointAttributeState +{ + public EquatableImmutableArray? Tags; + public bool? RequireAuthorization; + public EquatableImmutableArray? AuthorizationPolicies; + public bool? DisableAntiforgery; + public bool? AllowAnonymous; + public bool? ExcludeFromDescription; + public List? Accepts; + public List? Produces; + public List? ProducesProblem; + public List? ProducesValidationProblem; + public bool? RequireCors; + public string? CorsPolicyName; + public EquatableImmutableArray? RequiredHosts; + public bool? RequireRateLimiting; + public string? RateLimitingPolicyName; + public List? EndpointFilters; + public HashSet? EndpointFilterSet; + public bool HasAllowAnonymousAttribute; + public bool HasRequireAuthorizationAttribute; + public bool? ShortCircuit; + public bool? DisableValidation; + public bool? DisableRequestTimeout; + public bool? WithRequestTimeout; + public string? RequestTimeoutPolicyName; + public int? Order; + public string? EndpointGroupName; + public string? Summary; +} diff --git a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs new file mode 100644 index 0000000..7ffd09a --- /dev/null +++ b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs @@ -0,0 +1,22 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct EndpointConfiguration( + RequestHandlerMetadata Metadata, + bool RequireAuthorization, + EquatableImmutableArray? AuthorizationPolicies, + bool DisableAntiforgery, + bool AllowAnonymous, + bool RequireCors, + string? CorsPolicyName, + EquatableImmutableArray? RequiredHosts, + bool RequireRateLimiting, + string? RateLimitingPolicyName, + EquatableImmutableArray? EndpointFilterTypes, + bool ShortCircuit, + bool DisableValidation, + bool DisableRequestTimeout, + bool WithRequestTimeout, + string? RequestTimeoutPolicyName, + int? Order, + string? EndpointGroupName +); diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs new file mode 100644 index 0000000..60e4e33 --- /dev/null +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -0,0 +1,531 @@ +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Runtime.CompilerServices; +using Microsoft.CodeAnalysis; +using static GeneratedEndpoints.Common.AttributeSymbolMatcher; +using static GeneratedEndpoints.Common.Constants; + +namespace GeneratedEndpoints.Common; + +internal static class EndpointConfigurationFactory +{ + private sealed class GeneratedAttributeKindCacheEntry(GeneratedAttributeKind kind) + { + public GeneratedAttributeKind Kind { get; } = kind; + } + + private static readonly ConditionalWeakTable GeneratedAttributeKindCache = new(); + + public static EndpointConfiguration Create( + ImmutableArray attributes, + string? name, + string? displayName, + string? description, + bool enforceMethodRequireAuthorizationRules + ) + { + var state = new EndpointAttributeState(); + PopulateAttributeState(attributes, ref state); + + if (enforceMethodRequireAuthorizationRules && state is { HasRequireAuthorizationAttribute: true, HasAllowAnonymousAttribute: false }) + state.AllowAnonymous = false; + + var metadata = new RequestHandlerMetadata( + name, + displayName, + state.Summary, + description, + state.Tags, + ToEquatableOrNull(state.Accepts), + ToEquatableOrNull(state.Produces), + ToEquatableOrNull(state.ProducesProblem), + ToEquatableOrNull(state.ProducesValidationProblem), + state.ExcludeFromDescription ?? false + ); + + var withRequestTimeout = state.WithRequestTimeout ?? false; + var requestTimeoutPolicyName = withRequestTimeout ? state.RequestTimeoutPolicyName : null; + + return new EndpointConfiguration( + metadata, + state.RequireAuthorization ?? false, + state.AuthorizationPolicies, + state.DisableAntiforgery ?? false, + state.AllowAnonymous ?? false, + state.RequireCors ?? false, + state.CorsPolicyName, + state.RequiredHosts, + state.RequireRateLimiting ?? false, + state.RateLimitingPolicyName, + ToEquatableOrNull(state.EndpointFilters), + state.ShortCircuit ?? false, + state.DisableValidation ?? false, + state.DisableRequestTimeout ?? false, + withRequestTimeout, + requestTimeoutPolicyName, + state.Order, + state.EndpointGroupName + ); + } + + public static GeneratedAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) + { + var definition = attributeClass.OriginalDefinition; + var cacheEntry = GeneratedAttributeKindCache.GetValue( + definition, + static def => new GeneratedAttributeKindCacheEntry(GetGeneratedAttributeKindCore(def)) + ); + + return cacheEntry.Kind; + } + + public static string? NormalizeOptionalString(string? value) + { + return string.IsNullOrWhiteSpace(value) ? null : value!.Trim(); + } + + private static void PopulateAttributeState(ImmutableArray attributes, ref EndpointAttributeState state) + { + ref var tags = ref state.Tags; + ref var requireAuthorization = ref state.RequireAuthorization; + ref var authorizationPolicies = ref state.AuthorizationPolicies; + ref var disableAntiforgery = ref state.DisableAntiforgery; + ref var allowAnonymous = ref state.AllowAnonymous; + ref var excludeFromDescription = ref state.ExcludeFromDescription; + ref var accepts = ref state.Accepts; + ref var produces = ref state.Produces; + ref var producesProblem = ref state.ProducesProblem; + ref var producesValidationProblem = ref state.ProducesValidationProblem; + ref var requireCors = ref state.RequireCors; + ref var corsPolicyName = ref state.CorsPolicyName; + ref var requiredHosts = ref state.RequiredHosts; + ref var requireRateLimiting = ref state.RequireRateLimiting; + ref var rateLimitingPolicyName = ref state.RateLimitingPolicyName; + ref var endpointFilters = ref state.EndpointFilters; + ref var endpointFilterSet = ref state.EndpointFilterSet; + ref var hasAllowAnonymousAttribute = ref state.HasAllowAnonymousAttribute; + ref var hasRequireAuthorizationAttribute = ref state.HasRequireAuthorizationAttribute; + ref var shortCircuit = ref state.ShortCircuit; + ref var disableValidation = ref state.DisableValidation; + ref var disableRequestTimeout = ref state.DisableRequestTimeout; + ref var withRequestTimeout = ref state.WithRequestTimeout; + ref var requestTimeoutPolicyName = ref state.RequestTimeoutPolicyName; + ref var order = ref state.Order; + ref var endpointGroupName = ref state.EndpointGroupName; + ref var summary = ref state.Summary; + + foreach (var attribute in attributes) + { + var attributeClass = attribute.AttributeClass; + if (attributeClass is null) + continue; + + switch (GetGeneratedAttributeKind(attributeClass)) + { + case GeneratedAttributeKind.ShortCircuit: + shortCircuit = true; + continue; + case GeneratedAttributeKind.DisableValidation: + disableValidation = true; + continue; + case GeneratedAttributeKind.DisableRequestTimeout: + disableRequestTimeout = true; + withRequestTimeout = false; + requestTimeoutPolicyName = null; + continue; + case GeneratedAttributeKind.RequestTimeout: + { + disableRequestTimeout = false; + withRequestTimeout = true; + + string? policyName = null; + if (attribute.ConstructorArguments.Length > 0) + policyName = attribute.ConstructorArguments[0].Value as string; + + policyName ??= GetNamedStringValue(attribute, PolicyNameAttributeNamedParameter); + requestTimeoutPolicyName = NormalizeOptionalString(policyName); + continue; + } + case GeneratedAttributeKind.Order: + if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int orderValue) + order = orderValue; + continue; + case GeneratedAttributeKind.MapGroup: + { + var groupName = GetNamedStringValue(attribute, NameAttributeNamedParameter); + if (!string.IsNullOrEmpty(groupName)) + endpointGroupName = groupName; + continue; + } + case GeneratedAttributeKind.Summary: + if (attribute.ConstructorArguments.Length > 0) + { + var summaryValue = NormalizeOptionalString(attribute.ConstructorArguments[0].Value as string); + if (!string.IsNullOrEmpty(summaryValue)) + summary = summaryValue; + } + continue; + case GeneratedAttributeKind.Accepts: + TryAddAcceptsMetadata(attribute, attributeClass, ref accepts); + continue; + case GeneratedAttributeKind.ProducesResponse: + TryAddProducesMetadata(attribute, attributeClass, ref produces); + continue; + case GeneratedAttributeKind.RequireAuthorization: + requireAuthorization = true; + hasRequireAuthorizationAttribute = true; + if (attribute.ConstructorArguments.Length == 1) + { + var arg = attribute.ConstructorArguments[0]; + MergeInto(ref authorizationPolicies, arg.Values); + } + + continue; + case GeneratedAttributeKind.RequireCors: + requireCors = true; + corsPolicyName = attribute.ConstructorArguments.Length > 0 + ? NormalizeOptionalString(attribute.ConstructorArguments[0].Value as string) + : null; + continue; + case GeneratedAttributeKind.RequireHost: + if (attribute.ConstructorArguments.Length == 1) + { + var arg = attribute.ConstructorArguments[0]; + if (arg is { Kind: TypedConstantKind.Array, Values.Length: > 0 }) + MergeInto(ref requiredHosts, arg.Values); + else if (arg.Value is string singleHost && !string.IsNullOrWhiteSpace(singleHost)) + MergeInto(ref requiredHosts, [singleHost.Trim()]); + } + + continue; + case GeneratedAttributeKind.RequireRateLimiting: + { + var policyName = attribute.ConstructorArguments.Length > 0 + ? NormalizeOptionalString(attribute.ConstructorArguments[0].Value as string) + : null; + + if (!string.IsNullOrEmpty(policyName)) + { + requireRateLimiting = true; + rateLimitingPolicyName = policyName; + } + + continue; + } + case GeneratedAttributeKind.EndpointFilter: + TryAddEndpointFilter(attribute, attributeClass, ref endpointFilters, ref endpointFilterSet); + continue; + case GeneratedAttributeKind.DisableAntiforgery: + disableAntiforgery = true; + continue; + case GeneratedAttributeKind.ProducesProblem: + { + var statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesProblemStatusCode + ? producesProblemStatusCode + : 500; + var contentType = attribute.ConstructorArguments.Length > 1 + ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) + : null; + var additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; + + var producesProblemList = producesProblem ??= []; + producesProblemList.Add(new ProducesProblemMetadata(statusCode, contentType, additionalContentTypes)); + continue; + } + case GeneratedAttributeKind.ProducesValidationProblem: + { + var statusCode = + attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesValidationProblemStatusCode + ? producesValidationProblemStatusCode + : 400; + var contentType = attribute.ConstructorArguments.Length > 1 + ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) + : null; + var additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; + + var producesValidationProblemList = producesValidationProblem ??= []; + producesValidationProblemList.Add(new ProducesValidationProblemMetadata(statusCode, contentType, additionalContentTypes)); + continue; + } + } + + if (IsAttribute(attributeClass, AllowAnonymousAttributeName, AspNetCoreAuthorizationNamespaceParts)) + { + allowAnonymous = true; + hasAllowAnonymousAttribute = true; + continue; + } + + if (IsAttribute(attributeClass, "TagsAttribute", AspNetCoreHttpNamespaceParts)) + { + if (attribute.ConstructorArguments.Length > 0) + { + var arg = attribute.ConstructorArguments[0]; + MergeInto(ref tags, arg.Values); + } + + continue; + } + + if (IsAttribute(attributeClass, "ExcludeFromDescriptionAttribute", AspNetCoreRoutingNamespaceParts)) + excludeFromDescription = true; + } + } + + private static void MergeInto(ref EquatableImmutableArray? target, IEnumerable values) + { + var merged = MergeUnion(target, values); + target = merged.Count > 0 ? merged : null; + } + + private static void MergeInto(ref EquatableImmutableArray? target, ImmutableArray values) + { + if (values.IsDefaultOrEmpty) + return; + + List? normalized = null; + foreach (var value in values) + { + if (value.Value is not string stringValue) + continue; + + var trimmed = NormalizeOptionalString(stringValue); + if (trimmed is not { Length: > 0 }) + continue; + + normalized ??= new List(values.Length); + normalized.Add(trimmed); + } + + if (normalized is { Count: > 0 }) + MergeInto(ref target, normalized); + } + + internal static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, IEnumerable values) + { + List? list = null; + HashSet? seen = null; + + if (existing is { Count: > 0 }) + { + var count = existing.Value.Count; + list = new List(count + 4); + list.AddRange(existing.Value); + seen = new HashSet(existing.Value, StringComparer.OrdinalIgnoreCase); + } + + foreach (var value in values) + { + var normalized = NormalizeOptionalString(value); + if (normalized is not { Length: > 0 }) + continue; + + seen ??= new HashSet(StringComparer.OrdinalIgnoreCase); + if (!seen.Add(normalized)) + continue; + + list ??= []; + list.Add(normalized); + } + + return list?.ToEquatableImmutableArray() ?? EquatableImmutableArray.Empty; + } + + private static EquatableImmutableArray? ToEquatableOrNull(List? values) + { + return values is { Count: > 0 } ? values.ToEquatableImmutableArray() : null; + } + + private static string NormalizeRequiredContentType(string? contentType, string defaultValue) + { + return string.IsNullOrWhiteSpace(contentType) ? defaultValue : contentType!.Trim(); + } + + private static string? NormalizeOptionalContentType(string? contentType) + { + return string.IsNullOrWhiteSpace(contentType) ? null : contentType!.Trim(); + } + + private static EquatableImmutableArray? GetStringArrayValues(TypedConstant typedConstant) + { + if (typedConstant.Kind != TypedConstantKind.Array || typedConstant.Values.IsDefaultOrEmpty) + return null; + + var builder = ImmutableArray.CreateBuilder(typedConstant.Values.Length); + foreach (var value in typedConstant.Values) + { + if (value.Value is string s && !string.IsNullOrWhiteSpace(s)) + builder.Add(s.Trim()); + } + + return builder.Count > 0 ? builder.ToEquatableImmutable() : null; + } + + private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List? accepts) + { + string? requestType; + string contentType; + EquatableImmutableArray? additionalContentTypes; + var isOptional = GetNamedBoolValue(attribute, IsOptionalAttributeNamedParameter); + + if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) + { + requestType = attributeClass.TypeArguments[0] + .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + contentType = attribute.ConstructorArguments.Length > 0 + ? NormalizeRequiredContentType(attribute.ConstructorArguments[0].Value as string, "application/json") + : "application/json"; + additionalContentTypes = attribute.ConstructorArguments.Length > 1 ? GetStringArrayValues(attribute.ConstructorArguments[1]) : null; + } + else if (GetNamedTypeSymbol(attribute, RequestTypeAttributeNamedParameter) is { } requestTypeSymbol) + { + requestType = requestTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + contentType = attribute.ConstructorArguments.Length > 0 + ? NormalizeRequiredContentType(attribute.ConstructorArguments[0].Value as string, "application/json") + : "application/json"; + additionalContentTypes = attribute.ConstructorArguments.Length > 1 ? GetStringArrayValues(attribute.ConstructorArguments[1]) : null; + } + else + { + return; + } + + var acceptsList = accepts ??= []; + acceptsList.Add(new AcceptsMetadata(requestType, contentType, additionalContentTypes, isOptional)); + } + + private static void TryAddProducesMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List? produces) + { + string? responseType; + int statusCode; + string? contentType; + EquatableImmutableArray? additionalContentTypes; + + if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) + { + responseType = attributeClass.TypeArguments[0] + .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode + ? producesStatusCode + : 200; + contentType = attribute.ConstructorArguments.Length > 1 ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) : null; + additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; + } + else if (GetNamedTypeSymbol(attribute, ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol) + { + responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode + ? producesStatusCode + : 200; + contentType = attribute.ConstructorArguments.Length > 1 ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) : null; + additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; + } + else + { + return; + } + + var producesList = produces ??= []; + producesList.Add(new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes)); + } + + private static void TryAddEndpointFilter( + AttributeData attribute, + INamedTypeSymbol attributeClass, + ref List? endpointFilters, + ref HashSet? endpointFilterSet) + { + if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) + { + TryAddEndpointFilterType(attributeClass.TypeArguments[0], ref endpointFilters, ref endpointFilterSet); + return; + } + + if (attribute.ConstructorArguments.Length == 0) + return; + + if (attribute.ConstructorArguments[0].Value is ITypeSymbol filterTypeSymbol) + TryAddEndpointFilterType(filterTypeSymbol, ref endpointFilters, ref endpointFilterSet); + } + + private static void TryAddEndpointFilterType( + ITypeSymbol? typeSymbol, + ref List? endpointFilters, + ref HashSet? endpointFilterSet) + { + if (typeSymbol is null or ITypeParameterSymbol or IErrorTypeSymbol) + return; + + var displayString = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + if (string.IsNullOrWhiteSpace(displayString)) + return; + + endpointFilterSet ??= new HashSet(StringComparer.Ordinal); + if (!endpointFilterSet.Add(displayString)) + return; + + endpointFilters ??= []; + endpointFilters.Add(displayString); + } + + private static ITypeSymbol? GetNamedTypeSymbol(AttributeData attribute, string namedParameter) + { + foreach (var namedArg in attribute.NamedArguments) + { + if (namedArg.Key == namedParameter && namedArg.Value.Value is ITypeSymbol typeSymbol) + return typeSymbol; + } + + return null; + } + + private static bool GetNamedBoolValue(AttributeData attribute, string namedParameter, bool defaultValue = false) + { + foreach (var namedArg in attribute.NamedArguments) + { + if (namedArg.Key == namedParameter && namedArg.Value.Value is bool boolValue) + return boolValue; + } + + return defaultValue; + } + + private static string? GetNamedStringValue(AttributeData attribute, string namedParameter) + { + foreach (var namedArg in attribute.NamedArguments) + { + if (namedArg.Key == namedParameter && namedArg.Value.Value is string stringValue) + return NormalizeOptionalString(stringValue); + } + + return null; + } + + private static GeneratedAttributeKind GetGeneratedAttributeKindCore(INamedTypeSymbol definition) + { + if (!IsInNamespace(definition.ContainingNamespace, AttributesNamespaceParts)) + return GeneratedAttributeKind.None; + + return definition.Name switch + { + ShortCircuitAttributeName => GeneratedAttributeKind.ShortCircuit, + DisableValidationAttributeName => GeneratedAttributeKind.DisableValidation, + DisableRequestTimeoutAttributeName => GeneratedAttributeKind.DisableRequestTimeout, + RequestTimeoutAttributeName => GeneratedAttributeKind.RequestTimeout, + OrderAttributeName => GeneratedAttributeKind.Order, + MapGroupAttributeName => GeneratedAttributeKind.MapGroup, + SummaryAttributeName => GeneratedAttributeKind.Summary, + AcceptsAttributeName => GeneratedAttributeKind.Accepts, + ProducesResponseAttributeName => GeneratedAttributeKind.ProducesResponse, + RequireAuthorizationAttributeName => GeneratedAttributeKind.RequireAuthorization, + RequireCorsAttributeName => GeneratedAttributeKind.RequireCors, + RequireHostAttributeName => GeneratedAttributeKind.RequireHost, + RequireRateLimitingAttributeName => GeneratedAttributeKind.RequireRateLimiting, + EndpointFilterAttributeName => GeneratedAttributeKind.EndpointFilter, + DisableAntiforgeryAttributeName => GeneratedAttributeKind.DisableAntiforgery, + ProducesProblemAttributeName => GeneratedAttributeKind.ProducesProblem, + ProducesValidationProblemAttributeName => GeneratedAttributeKind.ProducesValidationProblem, + _ => GeneratedAttributeKind.None, + }; + } +} diff --git a/src/GeneratedEndpoints/Common/GeneratedAttributeKind.cs b/src/GeneratedEndpoints/Common/GeneratedAttributeKind.cs new file mode 100644 index 0000000..b9af832 --- /dev/null +++ b/src/GeneratedEndpoints/Common/GeneratedAttributeKind.cs @@ -0,0 +1,23 @@ +namespace GeneratedEndpoints.Common; + +internal enum GeneratedAttributeKind +{ + None = 0, + ShortCircuit, + DisableValidation, + DisableRequestTimeout, + RequestTimeout, + Order, + MapGroup, + Summary, + Accepts, + ProducesResponse, + RequireAuthorization, + RequireCors, + RequireHost, + RequireRateLimiting, + EndpointFilter, + DisableAntiforgery, + ProducesProblem, + ProducesValidationProblem, +} diff --git a/src/GeneratedEndpoints/Common/HttpAttributeDefinition.cs b/src/GeneratedEndpoints/Common/HttpAttributeDefinition.cs new file mode 100644 index 0000000..e740ab6 --- /dev/null +++ b/src/GeneratedEndpoints/Common/HttpAttributeDefinition.cs @@ -0,0 +1,11 @@ +using Microsoft.CodeAnalysis.Text; + +namespace GeneratedEndpoints.Common; + +internal readonly record struct HttpAttributeDefinition( + string Name, + string FullyQualifiedName, + string Hint, + string Verb, + SourceText SourceText +); diff --git a/src/GeneratedEndpoints/Common/Parameter.cs b/src/GeneratedEndpoints/Common/Parameter.cs new file mode 100644 index 0000000..6e4e53f --- /dev/null +++ b/src/GeneratedEndpoints/Common/Parameter.cs @@ -0,0 +1,3 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct Parameter(string Name, string Type, string BindingPrefix); diff --git a/src/GeneratedEndpoints/Common/ProducesMetadata.cs b/src/GeneratedEndpoints/Common/ProducesMetadata.cs new file mode 100644 index 0000000..e7f223c --- /dev/null +++ b/src/GeneratedEndpoints/Common/ProducesMetadata.cs @@ -0,0 +1,8 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct ProducesMetadata( + string ResponseType, + int StatusCode, + string? ContentType, + EquatableImmutableArray? AdditionalContentTypes +); diff --git a/src/GeneratedEndpoints/Common/ProducesProblemMetadata.cs b/src/GeneratedEndpoints/Common/ProducesProblemMetadata.cs new file mode 100644 index 0000000..04baeaa --- /dev/null +++ b/src/GeneratedEndpoints/Common/ProducesProblemMetadata.cs @@ -0,0 +1,7 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct ProducesProblemMetadata( + int StatusCode, + string? ContentType, + EquatableImmutableArray? AdditionalContentTypes +); diff --git a/src/GeneratedEndpoints/Common/ProducesValidationProblemMetadata.cs b/src/GeneratedEndpoints/Common/ProducesValidationProblemMetadata.cs new file mode 100644 index 0000000..46bca0c --- /dev/null +++ b/src/GeneratedEndpoints/Common/ProducesValidationProblemMetadata.cs @@ -0,0 +1,7 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct ProducesValidationProblemMetadata( + int StatusCode, + string? ContentType, + EquatableImmutableArray? AdditionalContentTypes +); diff --git a/src/GeneratedEndpoints/Common/RequestHandler.cs b/src/GeneratedEndpoints/Common/RequestHandler.cs new file mode 100644 index 0000000..eeb6736 --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandler.cs @@ -0,0 +1,9 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct RequestHandler( + RequestHandlerClass Class, + RequestHandlerMethod Method, + string HttpMethod, + string Pattern, + EndpointConfiguration Configuration +); diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClass.cs b/src/GeneratedEndpoints/Common/RequestHandlerClass.cs new file mode 100644 index 0000000..baae895 --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerClass.cs @@ -0,0 +1,11 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct RequestHandlerClass( + string Name, + bool IsStatic, + bool HasConfigureMethod, + bool ConfigureMethodAcceptsServiceProvider, + string? MapGroupPattern, + string? MapGroupBuilderIdentifier, + EndpointConfiguration Configuration +); diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs new file mode 100644 index 0000000..28dadc2 --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs @@ -0,0 +1,55 @@ +using System.Threading; +using GeneratedEndpoints; +using Microsoft.CodeAnalysis; + +namespace GeneratedEndpoints.Common; + +internal sealed class RequestHandlerClassCacheEntry +{ + private readonly object _lock = new(); + private RequestHandlerClass _value; + private bool _initialized; + + public RequestHandlerClass GetOrCreate( + INamedTypeSymbol classSymbol, + CompilationTypeCache compilationCache, + CancellationToken cancellationToken + ) + { + if (_initialized) + return _value; + + lock (_lock) + { + if (_initialized) + return _value; + + cancellationToken.ThrowIfCancellationRequested(); + + var name = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var isStatic = classSymbol.IsStatic; + var configureMethodDetails = MinimalApiGenerator.GetConfigureMethodDetails( + classSymbol, + compilationCache.EndpointConventionBuilderSymbol, + compilationCache.ServiceProviderSymbol, + cancellationToken + ); + + var mapGroupPattern = MinimalApiGenerator.GetMapGroupPattern(classSymbol); + var mapGroupIdentifier = mapGroupPattern is null ? null : MinimalApiGenerator.GetMapGroupIdentifier(name); + var classConfiguration = EndpointConfigurationFactory.Create(classSymbol.GetAttributes(), null, null, null, false); + + _value = new RequestHandlerClass( + name, + isStatic, + configureMethodDetails.HasConfigureMethod, + configureMethodDetails.ConfigureMethodAcceptsServiceProvider, + mapGroupPattern, + mapGroupIdentifier, + classConfiguration + ); + _initialized = true; + return _value; + } + } +} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerComparer.cs b/src/GeneratedEndpoints/Common/RequestHandlerComparer.cs new file mode 100644 index 0000000..f6b78eb --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerComparer.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; + +namespace GeneratedEndpoints.Common; + +internal sealed class RequestHandlerComparer : IComparer +{ + public static RequestHandlerComparer Instance { get; } = new(); + + public int Compare(RequestHandler x, RequestHandler y) + { + var comparison = string.Compare(x.Class.Name, y.Class.Name, StringComparison.Ordinal); + if (comparison != 0) + return comparison; + + comparison = string.Compare(x.Method.Name, y.Method.Name, StringComparison.Ordinal); + if (comparison != 0) + return comparison; + + comparison = string.Compare(x.HttpMethod, y.HttpMethod, StringComparison.Ordinal); + if (comparison != 0) + return comparison; + + return string.Compare(x.Pattern, y.Pattern, StringComparison.Ordinal); + } +} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerMetadata.cs b/src/GeneratedEndpoints/Common/RequestHandlerMetadata.cs new file mode 100644 index 0000000..8f76cf8 --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerMetadata.cs @@ -0,0 +1,14 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct RequestHandlerMetadata( + string? Name, + string? DisplayName, + string? Summary, + string? Description, + EquatableImmutableArray? Tags, + EquatableImmutableArray? Accepts, + EquatableImmutableArray? Produces, + EquatableImmutableArray? ProducesProblem, + EquatableImmutableArray? ProducesValidationProblem, + bool ExcludeFromDescription +); diff --git a/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs b/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs new file mode 100644 index 0000000..d6efe14 --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs @@ -0,0 +1,8 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct RequestHandlerMethod( + string Name, + bool IsStatic, + bool IsAwaitable, + EquatableImmutableArray Parameters +); diff --git a/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs b/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs new file mode 100644 index 0000000..1aaeb2f --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs @@ -0,0 +1,131 @@ +using System.Collections.Immutable; +using System.Threading; +using Microsoft.CodeAnalysis; +using static GeneratedEndpoints.Common.AttributeSymbolMatcher; +using static GeneratedEndpoints.Common.Constants; +using static GeneratedEndpoints.MinimalApiGenerator; + +namespace GeneratedEndpoints.Common; + +internal static class RequestHandlerParameterHelper +{ + public static EquatableImmutableArray Build(IMethodSymbol methodSymbol, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + var methodParameters = ImmutableArray.CreateBuilder(methodSymbol.Parameters.Length); + foreach (var parameter in methodSymbol.Parameters) + { + cancellationToken.ThrowIfCancellationRequested(); + + var source = BindingSource.None; + TypedConstant? typedKey = null; + string? bindingName = null; + + foreach (var attribute in parameter.GetAttributes()) + { + var attributeClass = attribute.AttributeClass; + if (attributeClass is null) + continue; + + var attributeSource = GetBindingSourceFromAttributeClass(attributeClass); + if (attributeSource == BindingSource.None) + continue; + + source = attributeSource; + switch (attributeSource) + { + case BindingSource.FromRoute: + case BindingSource.FromQuery: + case BindingSource.FromHeader: + case BindingSource.FromForm: + bindingName = GetBindingAttributeName(attribute) ?? bindingName; + break; + case BindingSource.FromKeyedServices: + typedKey = attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0] : null; + break; + } + } + + var parameterName = parameter.Name; + var parameterType = parameter.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var key = typedKey.HasValue ? ConstLiteral(typedKey.Value) : null; + var bindingPrefix = GetBindingSourceAttribute(source, key, bindingName); + methodParameters.Add(new Parameter(parameterName, parameterType, bindingPrefix)); + } + + return methodParameters.ToEquatableImmutable(); + } + + private static string? GetBindingAttributeName(AttributeData attribute) + { + foreach (var namedArg in attribute.NamedArguments) + { + if (string.Equals(namedArg.Key, NameAttributeNamedParameter, StringComparison.Ordinal) && namedArg.Value.Value is string namedValue) + { + var normalized = NormalizeBindingName(namedValue); + if (normalized is not null) + return normalized; + } + } + + if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is string constructorName) + return NormalizeBindingName(constructorName); + + return null; + } + + private static string? NormalizeBindingName(string? value) + { + if (string.IsNullOrWhiteSpace(value)) + return null; + + var trimmed = value!.Trim(); + return trimmed.Length > 0 ? trimmed : null; + } + + private static BindingSource GetBindingSourceFromAttributeClass(INamedTypeSymbol attributeClass) + { + var definition = attributeClass.OriginalDefinition; + var namespaceSymbol = definition.ContainingNamespace; + + return definition.Name switch + { + "FromRouteAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromRoute, + "FromQueryAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromQuery, + "FromHeaderAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromHeader, + "FromBodyAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromBody, + "FromFormAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromForm, + "FromServicesAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromServices, + "FromKeyedServicesAttribute" when IsInNamespace(namespaceSymbol, ExtensionsDependencyInjectionNamespaceParts) + => BindingSource.FromKeyedServices, + "AsParametersAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreHttpNamespaceParts) => BindingSource.AsParameters, + _ => BindingSource.None, + }; + } + + private static string GetBindingSourceAttribute(BindingSource source, string? key, string? bindingName) + { + return source switch + { + BindingSource.None => string.Empty, + BindingSource.FromRoute => FormatBindingAttribute("FromRoute", bindingName), + BindingSource.FromQuery => FormatBindingAttribute("FromQuery", bindingName), + BindingSource.FromHeader => FormatBindingAttribute("FromHeader", bindingName), + BindingSource.FromBody => FormatBindingAttribute("FromBody", bindingName), + BindingSource.FromForm => FormatBindingAttribute("FromForm", bindingName), + BindingSource.FromServices => "[FromServices] ", + BindingSource.FromKeyedServices => $"[FromKeyedServices({key})] ", + BindingSource.AsParameters => "[AsParameters] ", + _ => string.Empty, + }; + } + + private static string FormatBindingAttribute(string attributeName, string? bindingName) + { + if (bindingName is null) + return $"[{attributeName}] "; + + return $"[{attributeName}(Name = {StringLiteral(bindingName)})] "; + } +} diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.Types.cs b/src/GeneratedEndpoints/MinimalApiGenerator.Types.cs deleted file mode 100644 index 8ec7a24..0000000 --- a/src/GeneratedEndpoints/MinimalApiGenerator.Types.cs +++ /dev/null @@ -1,234 +0,0 @@ -using GeneratedEndpoints.Common; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.Text; - -namespace GeneratedEndpoints; - -public sealed partial class MinimalApiGenerator -{ - private readonly record struct HttpAttributeDefinition( - string Name, - string FullyQualifiedName, - string Hint, - string Verb, - SourceText SourceText - ); - - private readonly record struct RequestHandler( - RequestHandlerClass Class, - RequestHandlerMethod Method, - string HttpMethod, - string Pattern, - EndpointConfiguration Configuration - ); - - private readonly record struct RequestHandlerClass( - string Name, - bool IsStatic, - bool HasConfigureMethod, - bool ConfigureMethodAcceptsServiceProvider, - string? MapGroupPattern, - string? MapGroupBuilderIdentifier, - EndpointConfiguration Configuration - ); - - private readonly record struct EndpointConfiguration( - RequestHandlerMetadata Metadata, - bool RequireAuthorization, - EquatableImmutableArray? AuthorizationPolicies, - bool DisableAntiforgery, - bool AllowAnonymous, - bool RequireCors, - string? CorsPolicyName, - EquatableImmutableArray? RequiredHosts, - bool RequireRateLimiting, - string? RateLimitingPolicyName, - EquatableImmutableArray? EndpointFilterTypes, - bool ShortCircuit, - bool DisableValidation, - bool DisableRequestTimeout, - bool WithRequestTimeout, - string? RequestTimeoutPolicyName, - int? Order, - string? EndpointGroupName - ); - - private readonly record struct RequestHandlerMethod(string Name, bool IsStatic, bool IsAwaitable, EquatableImmutableArray Parameters); - - private readonly record struct RequestHandlerMetadata( - string? Name, - string? DisplayName, - string? Summary, - string? Description, - EquatableImmutableArray? Tags, - EquatableImmutableArray? Accepts, - EquatableImmutableArray? Produces, - EquatableImmutableArray? ProducesProblem, - EquatableImmutableArray? ProducesValidationProblem, - bool ExcludeFromDescription - ); - - private readonly record struct AcceptsMetadata( - string RequestType, - string ContentType, - EquatableImmutableArray? AdditionalContentTypes, - bool IsOptional - ); - - private readonly record struct ProducesMetadata( - string ResponseType, - int StatusCode, - string? ContentType, - EquatableImmutableArray? AdditionalContentTypes - ); - - private readonly record struct ProducesProblemMetadata(int StatusCode, string? ContentType, EquatableImmutableArray? AdditionalContentTypes); - - private readonly record struct ProducesValidationProblemMetadata( - int StatusCode, - string? ContentType, - EquatableImmutableArray? AdditionalContentTypes - ); - - private readonly record struct Parameter(string Name, string Type, string BindingPrefix); - - private readonly record struct ConfigureMethodDetails(bool HasConfigureMethod, bool ConfigureMethodAcceptsServiceProvider); - - private struct EndpointAttributeState - { - public EquatableImmutableArray? Tags; - public bool? RequireAuthorization; - public EquatableImmutableArray? AuthorizationPolicies; - public bool? DisableAntiforgery; - public bool? AllowAnonymous; - public bool? ExcludeFromDescription; - public List? Accepts; - public List? Produces; - public List? ProducesProblem; - public List? ProducesValidationProblem; - public bool? RequireCors; - public string? CorsPolicyName; - public EquatableImmutableArray? RequiredHosts; - public bool? RequireRateLimiting; - public string? RateLimitingPolicyName; - public List? EndpointFilters; - public HashSet? EndpointFilterSet; - public bool HasAllowAnonymousAttribute; - public bool HasRequireAuthorizationAttribute; - public bool? ShortCircuit; - public bool? DisableValidation; - public bool? DisableRequestTimeout; - public bool? WithRequestTimeout; - public string? RequestTimeoutPolicyName; - public int? Order; - public string? EndpointGroupName; - public string? Summary; - } - - private enum GeneratedAttributeKind - { - None = 0, - ShortCircuit, - DisableValidation, - DisableRequestTimeout, - RequestTimeout, - Order, - MapGroup, - Summary, - Accepts, - ProducesResponse, - RequireAuthorization, - RequireCors, - RequireHost, - RequireRateLimiting, - EndpointFilter, - DisableAntiforgery, - ProducesProblem, - ProducesValidationProblem, - } - - private enum BindingSource - { - None = 0, - FromRoute = 1, - FromQuery = 2, - FromHeader = 3, - FromBody = 4, - FromForm = 5, - FromServices = 6, - FromKeyedServices = 7, - AsParameters = 8, - } - - private sealed class RequestHandlerComparer : IComparer - { - public static RequestHandlerComparer Instance { get; } = new(); - - public int Compare(RequestHandler x, RequestHandler y) - { - var comparison = string.Compare(x.Class.Name, y.Class.Name, StringComparison.Ordinal); - if (comparison != 0) - return comparison; - - comparison = string.Compare(x.Method.Name, y.Method.Name, StringComparison.Ordinal); - if (comparison != 0) - return comparison; - - comparison = string.Compare(x.HttpMethod, y.HttpMethod, StringComparison.Ordinal); - if (comparison != 0) - return comparison; - - return string.Compare(x.Pattern, y.Pattern, StringComparison.Ordinal); - } - } - - private sealed class CompilationTypeCache(Compilation compilation) - { - public INamedTypeSymbol? EndpointConventionBuilderSymbol { get; } = - compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Builder.IEndpointConventionBuilder"); - - public INamedTypeSymbol? ServiceProviderSymbol { get; } = compilation.GetTypeByMetadataName("System.IServiceProvider"); - } - - private sealed class RequestHandlerClassCacheEntry - { - private readonly object _lock = new(); - private RequestHandlerClass _value; - private bool _initialized; - - public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, CompilationTypeCache compilationCache, CancellationToken cancellationToken) - { - if (_initialized) - return _value; - - lock (_lock) - { - if (_initialized) - return _value; - - cancellationToken.ThrowIfCancellationRequested(); - - var name = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var isStatic = classSymbol.IsStatic; - var configureMethodDetails = GetConfigureMethodDetails(classSymbol, compilationCache.EndpointConventionBuilderSymbol, - compilationCache.ServiceProviderSymbol, cancellationToken - ); - - var mapGroupPattern = GetMapGroupPattern(classSymbol); - var mapGroupIdentifier = mapGroupPattern is null ? null : GetMapGroupIdentifier(name); - var classConfiguration = GetEndpointConfiguration(classSymbol.GetAttributes(), null, null, null, false); - - _value = new RequestHandlerClass(name, isStatic, configureMethodDetails.HasConfigureMethod, - configureMethodDetails.ConfigureMethodAcceptsServiceProvider, mapGroupPattern, mapGroupIdentifier, classConfiguration - ); - _initialized = true; - return _value; - } - } - } - - private sealed class GeneratedAttributeKindCacheEntry(GeneratedAttributeKind kind) - { - public GeneratedAttributeKind Kind { get; } = kind; - } -} diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index c5ebad8..f63729c 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -3,17 +3,22 @@ using System.ComponentModel; using System.Diagnostics.CodeAnalysis; using System.Globalization; +using System.Runtime.CompilerServices; using System.Text; using GeneratedEndpoints.Common; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Text; +using static GeneratedEndpoints.Common.Constants; namespace GeneratedEndpoints; [Generator] public sealed partial class MinimalApiGenerator : IIncrementalGenerator { + private static readonly ConditionalWeakTable CompilationTypeCaches = new(); + private static readonly ConditionalWeakTable RequestHandlerClassCache = new(); + public void Initialize(IncrementalGeneratorInitializationContext context) { context.RegisterPostInitializationOutput(RegisterAttributes); @@ -35,15 +40,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.RegisterSourceOutput(requestHandlers, GenerateSource); } - private static HttpAttributeDefinition CreateHttpAttributeDefinition(string attributeName, string verb, bool allowOptionalPattern = false) - { - var fullyQualifiedName = $"{AttributesNamespace}.{attributeName}"; - var hint = $"{fullyQualifiedName}.gs.cs"; - var summaryVerb = verb == FallbackHttpMethod ? "fallback" : verb; - var source = GenerateHttpAttributeSource(AttributesNamespace, attributeName, summaryVerb, allowOptionalPattern); - return new HttpAttributeDefinition(attributeName, fullyQualifiedName, hint, verb, SourceText.From(source, Encoding.UTF8)); - } - private static IncrementalValueProvider> CombineRequestHandlers( ImmutableArray>> handlerProviders ) @@ -85,41 +81,6 @@ private static void RegisterAttributes(IncrementalGeneratorPostInitializationCon context.AddSource(ProducesValidationProblemAttributeHint, ProducesValidationProblemAttributeSourceText); } - private static string GenerateHttpAttributeSource(string attributesNamespace, string attributeName, string summaryVerb, bool allowOptionalPattern = false) - { - return $$""" - {{FileHeader}} - - namespace {{attributesNamespace}}; - - /// - /// Identifies a method as an HTTP {{summaryVerb}} minimal API endpoint with the specified route pattern. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{attributeName}} : global::System.Attribute - { - /// - /// Gets the route pattern for the endpoint. - /// - public string{{(allowOptionalPattern ? "?" : "")}} Pattern { get; } - - /// - /// Gets or sets the endpoint name. - /// - public string? Name { get; init; } - - /// - /// Initializes a new instance of the class. - /// - /// The route pattern for the endpoint. - public {{attributeName}}([global::System.Diagnostics.CodeAnalysis.StringSyntax("Route")] string{{(allowOptionalPattern ? "?" : "")}} pattern{{(allowOptionalPattern ? " = null" : "")}}) - { - Pattern = pattern; - } - } - """; - } - private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -147,7 +108,7 @@ private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToke name ??= RemoveAsyncSuffix(requestHandlerMethod.Name); - var methodConfiguration = GetEndpointConfiguration(requestHandlerMethodSymbol.GetAttributes(), name, displayName, description, true); + var methodConfiguration = EndpointConfigurationFactory.Create(requestHandlerMethodSymbol.GetAttributes(), name, displayName, description, true); var requestHandler = new RequestHandler(requestHandlerClass.Value, requestHandlerMethod, httpMethod, pattern, methodConfiguration); @@ -200,289 +161,24 @@ private static (string? DisplayName, string? Description) GetDisplayAndDescripti if (attributeClass is null) continue; - if (IsAttribute(attributeClass, nameof(DisplayNameAttribute), ComponentModelNamespaceParts)) + if (AttributeSymbolMatcher.IsAttribute(attributeClass, nameof(DisplayNameAttribute), ComponentModelNamespaceParts)) { - displayName = NormalizeOptionalString(attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : null); + displayName = EndpointConfigurationFactory.NormalizeOptionalString(attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : null); continue; } - if (IsAttribute(attributeClass, nameof(DescriptionAttribute), ComponentModelNamespaceParts)) - description = NormalizeOptionalString(attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : null); + if (AttributeSymbolMatcher.IsAttribute(attributeClass, nameof(DescriptionAttribute), ComponentModelNamespaceParts)) + description = EndpointConfigurationFactory.NormalizeOptionalString(attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : null); } return (displayName, description); } - private static EndpointConfiguration GetEndpointConfiguration( - ImmutableArray attributes, - string? name, - string? displayName, - string? description, - bool enforceMethodRequireAuthorizationRules - ) - { - var state = new EndpointAttributeState(); - - GetAdditionalRequestHandlerAttributeValues(attributes, ref state); - - if (enforceMethodRequireAuthorizationRules && state is { HasRequireAuthorizationAttribute: true, HasAllowAnonymousAttribute: false }) - state.AllowAnonymous = false; - - var metadata = new RequestHandlerMetadata(name, displayName, state.Summary, description, state.Tags, ToEquatableOrNull(state.Accepts), - ToEquatableOrNull(state.Produces), ToEquatableOrNull(state.ProducesProblem), ToEquatableOrNull(state.ProducesValidationProblem), - state.ExcludeFromDescription ?? false - ); - - var withRequestTimeout = state.WithRequestTimeout ?? false; - var requestTimeoutPolicyName = withRequestTimeout ? state.RequestTimeoutPolicyName : null; - - return new EndpointConfiguration(metadata, state.RequireAuthorization ?? false, state.AuthorizationPolicies, state.DisableAntiforgery ?? false, - state.AllowAnonymous ?? false, state.RequireCors ?? false, state.CorsPolicyName, state.RequiredHosts, state.RequireRateLimiting ?? false, - state.RateLimitingPolicyName, ToEquatableOrNull(state.EndpointFilters), state.ShortCircuit ?? false, state.DisableValidation ?? false, - state.DisableRequestTimeout ?? false, withRequestTimeout, requestTimeoutPolicyName, state.Order, state.EndpointGroupName - ); - } - - private static void GetAdditionalRequestHandlerAttributeValues(ImmutableArray attributes, ref EndpointAttributeState state) - { - ref var tags = ref state.Tags; - ref var requireAuthorization = ref state.RequireAuthorization; - ref var authorizationPolicies = ref state.AuthorizationPolicies; - ref var disableAntiforgery = ref state.DisableAntiforgery; - ref var allowAnonymous = ref state.AllowAnonymous; - ref var excludeFromDescription = ref state.ExcludeFromDescription; - ref var accepts = ref state.Accepts; - ref var produces = ref state.Produces; - ref var producesProblem = ref state.ProducesProblem; - ref var producesValidationProblem = ref state.ProducesValidationProblem; - ref var requireCors = ref state.RequireCors; - ref var corsPolicyName = ref state.CorsPolicyName; - ref var requiredHosts = ref state.RequiredHosts; - ref var requireRateLimiting = ref state.RequireRateLimiting; - ref var rateLimitingPolicyName = ref state.RateLimitingPolicyName; - ref var endpointFilters = ref state.EndpointFilters; - ref var endpointFilterSet = ref state.EndpointFilterSet; - ref var hasAllowAnonymousAttribute = ref state.HasAllowAnonymousAttribute; - ref var hasRequireAuthorizationAttribute = ref state.HasRequireAuthorizationAttribute; - ref var shortCircuit = ref state.ShortCircuit; - ref var disableValidation = ref state.DisableValidation; - ref var disableRequestTimeout = ref state.DisableRequestTimeout; - ref var withRequestTimeout = ref state.WithRequestTimeout; - ref var requestTimeoutPolicyName = ref state.RequestTimeoutPolicyName; - ref var order = ref state.Order; - ref var endpointGroupName = ref state.EndpointGroupName; - ref var summary = ref state.Summary; - - foreach (var attribute in attributes) - { - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - switch (GetGeneratedAttributeKind(attributeClass)) - { - case GeneratedAttributeKind.ShortCircuit: - shortCircuit = true; - continue; - case GeneratedAttributeKind.DisableValidation: - disableValidation = true; - continue; - case GeneratedAttributeKind.DisableRequestTimeout: - disableRequestTimeout = true; - withRequestTimeout = false; - requestTimeoutPolicyName = null; - continue; - case GeneratedAttributeKind.RequestTimeout: - { - disableRequestTimeout = false; - withRequestTimeout = true; - - string? policyName = null; - if (attribute.ConstructorArguments.Length > 0) - policyName = attribute.ConstructorArguments[0].Value as string; - - policyName ??= GetNamedStringValue(attribute, PolicyNameAttributeNamedParameter); - requestTimeoutPolicyName = NormalizeOptionalString(policyName); - continue; - } - case GeneratedAttributeKind.Order: - if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int orderValue) - order = orderValue; - continue; - case GeneratedAttributeKind.MapGroup: - { - var groupName = GetNamedStringValue(attribute, NameAttributeNamedParameter); - if (!string.IsNullOrEmpty(groupName)) - endpointGroupName = groupName; - continue; - } - case GeneratedAttributeKind.Summary: - if (attribute.ConstructorArguments.Length > 0) - { - var summaryValue = NormalizeOptionalString(attribute.ConstructorArguments[0].Value as string); - if (!string.IsNullOrEmpty(summaryValue)) - summary = summaryValue; - } - continue; - case GeneratedAttributeKind.Accepts: - TryAddAcceptsMetadata(attribute, attributeClass, ref accepts); - continue; - case GeneratedAttributeKind.ProducesResponse: - TryAddProducesMetadata(attribute, attributeClass, ref produces); - continue; - case GeneratedAttributeKind.RequireAuthorization: - requireAuthorization = true; - hasRequireAuthorizationAttribute = true; - if (attribute.ConstructorArguments.Length == 1) - { - var arg = attribute.ConstructorArguments[0]; - MergeInto(ref authorizationPolicies, arg.Values); - } - - continue; - case GeneratedAttributeKind.RequireCors: - requireCors = true; - corsPolicyName = attribute.ConstructorArguments.Length > 0 - ? NormalizeOptionalString(attribute.ConstructorArguments[0].Value as string) - : null; - continue; - case GeneratedAttributeKind.RequireHost: - if (attribute.ConstructorArguments.Length == 1) - { - var arg = attribute.ConstructorArguments[0]; - if (arg is { Kind: TypedConstantKind.Array, Values.Length: > 0 }) - MergeInto(ref requiredHosts, arg.Values); - else if (arg.Value is string singleHost && !string.IsNullOrWhiteSpace(singleHost)) - MergeInto(ref requiredHosts, [singleHost.Trim()]); - } - - continue; - case GeneratedAttributeKind.RequireRateLimiting: - { - var policyName = attribute.ConstructorArguments.Length > 0 - ? NormalizeOptionalString(attribute.ConstructorArguments[0].Value as string) - : null; - - if (!string.IsNullOrEmpty(policyName)) - { - requireRateLimiting = true; - rateLimitingPolicyName = policyName; - } - - continue; - } - case GeneratedAttributeKind.EndpointFilter: - TryAddEndpointFilter(attribute, attributeClass, ref endpointFilters, ref endpointFilterSet); - continue; - case GeneratedAttributeKind.DisableAntiforgery: - disableAntiforgery = true; - continue; - case GeneratedAttributeKind.ProducesProblem: - { - var statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesProblemStatusCode - ? producesProblemStatusCode - : 500; - var contentType = attribute.ConstructorArguments.Length > 1 - ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) - : null; - var additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; - - var producesProblemList = producesProblem ??= []; - producesProblemList.Add(new ProducesProblemMetadata(statusCode, contentType, additionalContentTypes)); - continue; - } - case GeneratedAttributeKind.ProducesValidationProblem: - { - var statusCode = - attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesValidationProblemStatusCode - ? producesValidationProblemStatusCode - : 400; - var contentType = attribute.ConstructorArguments.Length > 1 - ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) - : null; - var additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; - - var producesValidationProblemList = producesValidationProblem ??= []; - producesValidationProblemList.Add(new ProducesValidationProblemMetadata(statusCode, contentType, additionalContentTypes)); - continue; - } - } - - if (IsAttribute(attributeClass, AllowAnonymousAttributeName, AspNetCoreAuthorizationNamespaceParts)) - { - allowAnonymous = true; - hasAllowAnonymousAttribute = true; - continue; - } - - if (IsAttribute(attributeClass, "TagsAttribute", AspNetCoreHttpNamespaceParts)) - { - if (attribute.ConstructorArguments.Length > 0) - { - var arg = attribute.ConstructorArguments[0]; - MergeInto(ref tags, arg.Values); - } - - continue; - } - - if (IsAttribute(attributeClass, "ExcludeFromDescriptionAttribute", AspNetCoreRoutingNamespaceParts)) - excludeFromDescription = true; - } - } - - private static void MergeInto(ref EquatableImmutableArray? target, IEnumerable values) - { - var merged = MergeUnion(target, values); - target = merged.Count > 0 ? merged : null; - } - - private static void MergeInto(ref EquatableImmutableArray? target, ImmutableArray values) - { - if (values.IsDefaultOrEmpty) - return; - List? normalized = null; - foreach (var value in values) - { - if (value.Value is not string stringValue) - continue; - - var trimmed = NormalizeOptionalString(stringValue); - if (trimmed is not { Length: > 0 }) - continue; - - normalized ??= new List(values.Length); - normalized.Add(trimmed); - } - - if (normalized is { Count: > 0 }) - MergeInto(ref target, normalized); - } - - private static EquatableImmutableArray? ToEquatableOrNull(List? values) - { - return values is { Count: > 0 } ? values.ToEquatableImmutableArray() : null; - } - - private static string NormalizeRequiredContentType(string? contentType, string defaultValue) - { - return string.IsNullOrWhiteSpace(contentType) ? defaultValue : contentType!.Trim(); - } - - private static string? NormalizeOptionalContentType(string? contentType) - { - return string.IsNullOrWhiteSpace(contentType) ? null : contentType!.Trim(); - } - - private static string? NormalizeOptionalString(string? value) - { - return string.IsNullOrWhiteSpace(value) ? null : value!.Trim(); - } [SuppressMessage("Major Code Smell", "S3398:Move this method into a class of its own", Justification = "Shared helper for multiple caching paths.")] - private static string? GetMapGroupPattern(INamedTypeSymbol classSymbol) + internal static string? GetMapGroupPattern(INamedTypeSymbol classSymbol) { foreach (var attribute in classSymbol.GetAttributes()) { @@ -490,7 +186,7 @@ private static string NormalizeRequiredContentType(string? contentType, string d if (attributeClass is null) continue; - if (GetGeneratedAttributeKind(attributeClass) != GeneratedAttributeKind.MapGroup) + if (EndpointConfigurationFactory.GetGeneratedAttributeKind(attributeClass) != GeneratedAttributeKind.MapGroup) continue; if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is string pattern) @@ -501,7 +197,7 @@ private static string NormalizeRequiredContentType(string? contentType, string d } [SuppressMessage("Major Code Smell", "S3398:Move this method into a class of its own", Justification = "Shared helper for multiple caching paths.")] - private static string GetMapGroupIdentifier(string className) + internal static string GetMapGroupIdentifier(string className) { if (className.StartsWith(GlobalPrefix, StringComparison.Ordinal)) className = className.Substring(GlobalPrefix.Length); @@ -516,267 +212,6 @@ private static string GetMapGroupIdentifier(string className) return StringBuilderPool.ToStringAndReturn(builder); } - private static EquatableImmutableArray? GetStringArrayValues(TypedConstant typedConstant) - { - if (typedConstant.Kind != TypedConstantKind.Array || typedConstant.Values.IsDefaultOrEmpty) - return null; - - var builder = ImmutableArray.CreateBuilder(typedConstant.Values.Length); - foreach (var value in typedConstant.Values) - { - if (value.Value is string s && !string.IsNullOrWhiteSpace(s)) - builder.Add(s.Trim()); - } - - return builder.Count > 0 ? builder.ToEquatableImmutable() : null; - } - - private static GeneratedAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) - { - var definition = attributeClass.OriginalDefinition; - var cacheEntry = - GeneratedAttributeKindCache.GetValue(definition, static def => new GeneratedAttributeKindCacheEntry(GetGeneratedAttributeKindCore(def))); - - return cacheEntry.Kind; - } - - private static GeneratedAttributeKind GetGeneratedAttributeKindCore(INamedTypeSymbol definition) - { - if (!IsInNamespace(definition.ContainingNamespace, AttributesNamespaceParts)) - return GeneratedAttributeKind.None; - - return definition.Name switch - { - ShortCircuitAttributeName => GeneratedAttributeKind.ShortCircuit, - DisableValidationAttributeName => GeneratedAttributeKind.DisableValidation, - DisableRequestTimeoutAttributeName => GeneratedAttributeKind.DisableRequestTimeout, - RequestTimeoutAttributeName => GeneratedAttributeKind.RequestTimeout, - OrderAttributeName => GeneratedAttributeKind.Order, - MapGroupAttributeName => GeneratedAttributeKind.MapGroup, - SummaryAttributeName => GeneratedAttributeKind.Summary, - AcceptsAttributeName => GeneratedAttributeKind.Accepts, - ProducesResponseAttributeName => GeneratedAttributeKind.ProducesResponse, - RequireAuthorizationAttributeName => GeneratedAttributeKind.RequireAuthorization, - RequireCorsAttributeName => GeneratedAttributeKind.RequireCors, - RequireHostAttributeName => GeneratedAttributeKind.RequireHost, - RequireRateLimitingAttributeName => GeneratedAttributeKind.RequireRateLimiting, - EndpointFilterAttributeName => GeneratedAttributeKind.EndpointFilter, - DisableAntiforgeryAttributeName => GeneratedAttributeKind.DisableAntiforgery, - ProducesProblemAttributeName => GeneratedAttributeKind.ProducesProblem, - ProducesValidationProblemAttributeName => GeneratedAttributeKind.ProducesValidationProblem, - _ => GeneratedAttributeKind.None, - }; - } - - private static bool IsAttribute(INamedTypeSymbol attributeClass, string attributeName, string[] namespaceParts) - { - var definition = attributeClass.OriginalDefinition; - return definition.Name == attributeName && IsInNamespace(definition.ContainingNamespace, namespaceParts); - } - - private static BindingSource GetBindingSourceFromAttributeClass(INamedTypeSymbol attributeClass) - { - var definition = attributeClass.OriginalDefinition; - var namespaceSymbol = definition.ContainingNamespace; - - return definition.Name switch - { - "FromRouteAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromRoute, - "FromQueryAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromQuery, - "FromHeaderAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromHeader, - "FromBodyAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromBody, - "FromFormAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromForm, - "FromServicesAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromServices, - "FromKeyedServicesAttribute" when IsInNamespace(namespaceSymbol, ExtensionsDependencyInjectionNamespaceParts) - => BindingSource.FromKeyedServices, - "AsParametersAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreHttpNamespaceParts) => BindingSource.AsParameters, - _ => BindingSource.None, - }; - } - - private static bool IsInNamespace(INamespaceSymbol? namespaceSymbol, string[] namespaceParts) - { - for (var i = namespaceParts.Length - 1; i >= 0; i--) - { - if (namespaceSymbol is null || namespaceSymbol.Name != namespaceParts[i]) - return false; - - namespaceSymbol = namespaceSymbol.ContainingNamespace; - } - - return namespaceSymbol is null || namespaceSymbol.IsGlobalNamespace; - } - - private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List? accepts) - { - string? requestType; - string contentType; - EquatableImmutableArray? additionalContentTypes; - var isOptional = GetNamedBoolValue(attribute, IsOptionalAttributeNamedParameter); - - if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) - { - requestType = attributeClass.TypeArguments[0] - .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - contentType = attribute.ConstructorArguments.Length > 0 - ? NormalizeRequiredContentType(attribute.ConstructorArguments[0].Value as string, "application/json") - : "application/json"; - additionalContentTypes = attribute.ConstructorArguments.Length > 1 ? GetStringArrayValues(attribute.ConstructorArguments[1]) : null; - } - else if (GetNamedTypeSymbol(attribute, RequestTypeAttributeNamedParameter) is { } requestTypeSymbol) - { - requestType = requestTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - contentType = attribute.ConstructorArguments.Length > 0 - ? NormalizeRequiredContentType(attribute.ConstructorArguments[0].Value as string, "application/json") - : "application/json"; - additionalContentTypes = attribute.ConstructorArguments.Length > 1 ? GetStringArrayValues(attribute.ConstructorArguments[1]) : null; - } - else - { - return; - } - - var acceptsList = accepts ??= []; - acceptsList.Add(new AcceptsMetadata(requestType, contentType, additionalContentTypes, isOptional)); - } - - private static void TryAddProducesMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List? produces) - { - string? responseType; - int statusCode; - string? contentType; - EquatableImmutableArray? additionalContentTypes; - - if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) - { - responseType = attributeClass.TypeArguments[0] - .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode - ? producesStatusCode - : 200; - contentType = attribute.ConstructorArguments.Length > 1 ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) : null; - additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; - } - else if (GetNamedTypeSymbol(attribute, ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol) - { - responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode - ? producesStatusCode - : 200; - contentType = attribute.ConstructorArguments.Length > 1 ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) : null; - additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; - } - else - { - return; - } - - var producesList = produces ??= []; - producesList.Add(new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes)); - } - - private static void TryAddEndpointFilter( - AttributeData attribute, - INamedTypeSymbol attributeClass, - ref List? endpointFilters, - ref HashSet? endpointFilterSet) - { - if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) - { - TryAddEndpointFilterType(attributeClass.TypeArguments[0], ref endpointFilters, ref endpointFilterSet); - return; - } - - if (attribute.ConstructorArguments.Length == 0) - return; - - if (attribute.ConstructorArguments[0].Value is ITypeSymbol filterTypeSymbol) - TryAddEndpointFilterType(filterTypeSymbol, ref endpointFilters, ref endpointFilterSet); - } - - private static void TryAddEndpointFilterType( - ITypeSymbol? typeSymbol, - ref List? endpointFilters, - ref HashSet? endpointFilterSet) - { - if (typeSymbol is null or ITypeParameterSymbol or IErrorTypeSymbol) - return; - - var displayString = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - if (string.IsNullOrWhiteSpace(displayString)) - return; - - endpointFilterSet ??= new HashSet(StringComparer.Ordinal); - if (!endpointFilterSet.Add(displayString)) - return; - - endpointFilters ??= []; - endpointFilters.Add(displayString); - } - - private static ITypeSymbol? GetNamedTypeSymbol(AttributeData attribute, string namedParameter) - { - foreach (var namedArg in attribute.NamedArguments) - { - if (namedArg.Key == namedParameter && namedArg.Value.Value is ITypeSymbol typeSymbol) - return typeSymbol; - } - - return null; - } - - private static bool GetNamedBoolValue(AttributeData attribute, string namedParameter, bool defaultValue = false) - { - foreach (var namedArg in attribute.NamedArguments) - { - if (namedArg.Key == namedParameter && namedArg.Value.Value is bool boolValue) - return boolValue; - } - - return defaultValue; - } - - private static string? GetNamedStringValue(AttributeData attribute, string namedParameter) - { - foreach (var namedArg in attribute.NamedArguments) - { - if (namedArg.Key == namedParameter && namedArg.Value.Value is string stringValue) - return NormalizeOptionalString(stringValue); - } - - return null; - } - - private static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, IEnumerable values) - { - List? list = null; - HashSet? seen = null; - - if (existing is { Count: > 0 }) - { - var count = existing.Value.Count; - list = new List(count + 4); - list.AddRange(existing.Value); - seen = new HashSet(existing.Value, StringComparer.OrdinalIgnoreCase); - } - - foreach (var value in values) - { - var normalized = NormalizeOptionalString(value); - if (normalized is not { Length: > 0 }) - continue; - - seen ??= new HashSet(StringComparer.OrdinalIgnoreCase); - if (!seen.Add(normalized)) - continue; - - list ??= []; - - list.Add(normalized); - } - - return list?.ToEquatableImmutableArray() ?? EquatableImmutableArray.Empty; - } private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol methodSymbol, CancellationToken cancellationToken) { @@ -785,7 +220,7 @@ private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol method var name = methodSymbol.Name; var isStatic = methodSymbol.IsStatic; var isAwaitable = methodSymbol.ReturnType.IsTask(out _) || methodSymbol.ReturnType.IsValueTask(out _); - var parameters = GetRequestHandlerParameters(methodSymbol, cancellationToken); + var parameters = RequestHandlerParameterHelper.Build(methodSymbol, cancellationToken); var requestHandlerMethod = new RequestHandlerMethod(name, isStatic, isAwaitable, parameters); @@ -812,7 +247,7 @@ private static CompilationTypeCache GetCompilationTypeCache(Compilation compilat } [SuppressMessage("Major Code Smell", "S3398:Move this method into a class of its own", Justification = "Shared helper reused by caching infrastructure.")] - private static ConfigureMethodDetails GetConfigureMethodDetails( + internal static ConfigureMethodDetails GetConfigureMethodDetails( INamedTypeSymbol classSymbol, INamedTypeSymbol? endpointConventionBuilderSymbol, INamedTypeSymbol? serviceProviderSymbol, @@ -953,81 +388,6 @@ private static bool MatchesServiceProvider(ITypeSymbol typeSymbol) return string.Equals(containingNamespace, "System", StringComparison.Ordinal); } - private static EquatableImmutableArray GetRequestHandlerParameters(IMethodSymbol methodSymbol, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - var methodParameters = ImmutableArray.CreateBuilder(methodSymbol.Parameters.Length); - foreach (var parameter in methodSymbol.Parameters) - { - cancellationToken.ThrowIfCancellationRequested(); - - var source = BindingSource.None; - TypedConstant? typedKey = null; - string? bindingName = null; - - foreach (var attribute in parameter.GetAttributes()) - { - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - var attributeSource = GetBindingSourceFromAttributeClass(attributeClass); - if (attributeSource == BindingSource.None) - continue; - - source = attributeSource; - switch (attributeSource) - { - case BindingSource.FromRoute: - case BindingSource.FromQuery: - case BindingSource.FromHeader: - case BindingSource.FromForm: - bindingName = GetBindingAttributeName(attribute) ?? bindingName; - break; - case BindingSource.FromKeyedServices: - typedKey = attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0] : null; - break; - } - } - - var parameterName = parameter.Name; - var parameterType = parameter.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var key = typedKey.HasValue ? ConstLiteral(typedKey.Value) : null; - var bindingPrefix = GetBindingSourceAttribute(source, key, bindingName); - methodParameters.Add(new Parameter(parameterName, parameterType, bindingPrefix)); - } - - return methodParameters.ToEquatableImmutable(); - } - - private static string? GetBindingAttributeName(AttributeData attribute) - { - foreach (var namedArg in attribute.NamedArguments) - { - if (string.Equals(namedArg.Key, NameAttributeNamedParameter, StringComparison.Ordinal) && namedArg.Value.Value is string namedValue) - { - var normalized = NormalizeBindingName(namedValue); - if (normalized is not null) - return normalized; - } - } - - if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is string constructorName) - return NormalizeBindingName(constructorName); - - return null; - } - - private static string? NormalizeBindingName(string? value) - { - if (string.IsNullOrWhiteSpace(value)) - return null; - - var trimmed = value!.Trim(); - return trimmed.Length > 0 ? trimmed : null; - } - private static void GenerateSource(SourceProductionContext context, ImmutableArray requestHandlers) { context.CancellationToken.ThrowIfCancellationRequested(); @@ -1748,7 +1108,7 @@ private static RequestHandlerMetadata MergeRequestHandlerMetadata(RequestHandler if (second is not { Count: > 0 }) return first; - var merged = MergeUnion(first, second.Value); + var merged = EndpointConfigurationFactory.MergeUnion(first, second.Value); return merged.Count > 0 ? merged : null; } @@ -1778,31 +1138,6 @@ private static RequestHandlerMetadata MergeRequestHandlerMetadata(RequestHandler }; } - private static string GetBindingSourceAttribute(BindingSource source, string? key, string? bindingName) - { - return source switch - { - BindingSource.None => "", - BindingSource.FromRoute => FormatBindingAttribute("FromRoute", bindingName), - BindingSource.FromQuery => FormatBindingAttribute("FromQuery", bindingName), - BindingSource.FromHeader => FormatBindingAttribute("FromHeader", bindingName), - BindingSource.FromBody => FormatBindingAttribute("FromBody", bindingName), - BindingSource.FromForm => FormatBindingAttribute("FromForm", bindingName), - BindingSource.FromServices => "[FromServices] ", - BindingSource.FromKeyedServices => $"[FromKeyedServices({key})] ", - BindingSource.AsParameters => "[AsParameters] ", - _ => throw new NotImplementedException(), - }; - } - - private static string FormatBindingAttribute(string attributeName, string? bindingName) - { - if (bindingName is null) - return $"[{attributeName}] "; - - return $"[{attributeName}(Name = {StringLiteral(bindingName)})] "; - } - private static StringBuilder GetUseEndpointHandlersStringBuilder(ImmutableArray requestHandlers) { const int baseSize = 4096; @@ -1819,7 +1154,7 @@ private static StringBuilder GetUseEndpointHandlersStringBuilder(ImmutableArray< } [SuppressMessage("Globalization", "CA1308: Normalize strings to uppercase", Justification = "C# boolean literals must be lowercase.")] - private static string ConstLiteral(TypedConstant tc) + internal static string ConstLiteral(TypedConstant tc) { if (tc.IsNull) return "null"; @@ -1877,7 +1212,7 @@ private static string IntegralLiteral(object? value, SpecialType underlying) }; } - private static string StringLiteral(string? value) + internal static string StringLiteral(string? value) { if (value is null) return "null"; From c6cafe8dccb9070c99c2d131f1b59126b97fb56a Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 17:57:36 -0500 Subject: [PATCH 02/32] Cleanup. --- .../Common/Constants.GeneratedSources.cs | 1 - .../Common/EndpointAttributeState.cs | 2 - .../Common/EndpointConfigurationFactory.cs | 1 - .../Common/RequestHandlerClassCacheEntry.cs | 205 +++++++++- .../Common/RequestHandlerComparer.cs | 3 - .../Common/RequestHandlerParameterHelper.cs | 6 +- .../Common/StringExtensions.cs | 68 ++++ .../Common/TypeSymbolExtensions.cs | 96 ++--- .../Common/TypedConstantExtensions.cs | 82 ++++ src/GeneratedEndpoints/MinimalApiGenerator.cs | 377 ++---------------- 10 files changed, 404 insertions(+), 437 deletions(-) create mode 100644 src/GeneratedEndpoints/Common/StringExtensions.cs create mode 100644 src/GeneratedEndpoints/Common/TypedConstantExtensions.cs diff --git a/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs b/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs index b2ef684..534d742 100644 --- a/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs +++ b/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs @@ -1,7 +1,6 @@ using System.Collections.Immutable; using System.Text; using Microsoft.CodeAnalysis.Text; -using GeneratedEndpoints; namespace GeneratedEndpoints.Common; diff --git a/src/GeneratedEndpoints/Common/EndpointAttributeState.cs b/src/GeneratedEndpoints/Common/EndpointAttributeState.cs index e455138..3ac4ed1 100644 --- a/src/GeneratedEndpoints/Common/EndpointAttributeState.cs +++ b/src/GeneratedEndpoints/Common/EndpointAttributeState.cs @@ -1,5 +1,3 @@ -using System.Collections.Generic; - namespace GeneratedEndpoints.Common; internal struct EndpointAttributeState diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index 60e4e33..df02397 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -1,4 +1,3 @@ -using System.Collections.Generic; using System.Collections.Immutable; using System.Runtime.CompilerServices; using Microsoft.CodeAnalysis; diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs index 28dadc2..647cc95 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs @@ -1,6 +1,6 @@ -using System.Threading; -using GeneratedEndpoints; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static GeneratedEndpoints.Common.Constants; namespace GeneratedEndpoints.Common; @@ -10,11 +10,7 @@ internal sealed class RequestHandlerClassCacheEntry private RequestHandlerClass _value; private bool _initialized; - public RequestHandlerClass GetOrCreate( - INamedTypeSymbol classSymbol, - CompilationTypeCache compilationCache, - CancellationToken cancellationToken - ) + public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, CompilationTypeCache compilationCache, CancellationToken cancellationToken) { if (_initialized) return _value; @@ -28,28 +24,193 @@ CancellationToken cancellationToken var name = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var isStatic = classSymbol.IsStatic; - var configureMethodDetails = MinimalApiGenerator.GetConfigureMethodDetails( - classSymbol, - compilationCache.EndpointConventionBuilderSymbol, - compilationCache.ServiceProviderSymbol, - cancellationToken + var configureMethodDetails = GetConfigureMethodDetails(classSymbol, compilationCache.EndpointConventionBuilderSymbol, + compilationCache.ServiceProviderSymbol, cancellationToken ); - var mapGroupPattern = MinimalApiGenerator.GetMapGroupPattern(classSymbol); - var mapGroupIdentifier = mapGroupPattern is null ? null : MinimalApiGenerator.GetMapGroupIdentifier(name); + var mapGroupPattern = GetMapGroupPattern(classSymbol); + var mapGroupIdentifier = mapGroupPattern is null ? null : GetMapGroupIdentifier(name); var classConfiguration = EndpointConfigurationFactory.Create(classSymbol.GetAttributes(), null, null, null, false); - _value = new RequestHandlerClass( - name, - isStatic, - configureMethodDetails.HasConfigureMethod, - configureMethodDetails.ConfigureMethodAcceptsServiceProvider, - mapGroupPattern, - mapGroupIdentifier, - classConfiguration + _value = new RequestHandlerClass(name, isStatic, configureMethodDetails.HasConfigureMethod, + configureMethodDetails.ConfigureMethodAcceptsServiceProvider, mapGroupPattern, mapGroupIdentifier, classConfiguration ); _initialized = true; return _value; } } + + private static ConfigureMethodDetails GetConfigureMethodDetails( + INamedTypeSymbol classSymbol, + INamedTypeSymbol? endpointConventionBuilderSymbol, + INamedTypeSymbol? serviceProviderSymbol, + CancellationToken cancellationToken + ) + { + cancellationToken.ThrowIfCancellationRequested(); + + var hasConfigureMethod = false; + var acceptsServiceProvider = false; + foreach (var member in classSymbol.GetMembers(ConfigureMethodName)) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (member is not IMethodSymbol methodSymbol) + continue; + + if (IsConfigureMethod(methodSymbol, endpointConventionBuilderSymbol, serviceProviderSymbol, out var methodAcceptsServiceProvider)) + { + hasConfigureMethod = true; + if (methodAcceptsServiceProvider) + { + acceptsServiceProvider = true; + break; + } + } + } + + return new ConfigureMethodDetails(hasConfigureMethod, acceptsServiceProvider); + } + + private static bool IsConfigureMethod( + IMethodSymbol methodSymbol, + INamedTypeSymbol? endpointConventionBuilderSymbol, + INamedTypeSymbol? serviceProviderSymbol, + out bool acceptsServiceProvider + ) + { + acceptsServiceProvider = false; + + if (!methodSymbol.IsStatic) + return false; + + if (methodSymbol.TypeParameters.Length != 1) + return false; + + if (methodSymbol.Parameters.Length is < 1 or > 2) + return false; + + var builderTypeParameter = methodSymbol.TypeParameters[0]; + var builderParameter = methodSymbol.Parameters[0]; + + if (!SymbolEqualityComparer.Default.Equals(builderParameter.Type, builderTypeParameter)) + return false; + + if (methodSymbol.Parameters.Length == 2) + { + var serviceProviderParameter = methodSymbol.Parameters[1]; + if (!IsServiceProviderParameter(serviceProviderParameter.Type, serviceProviderSymbol)) + return false; + + acceptsServiceProvider = true; + } + + if (!methodSymbol.ReturnsVoid) + return false; + + if (!HasEndpointConventionBuilderConstraint(builderTypeParameter, methodSymbol, endpointConventionBuilderSymbol)) + return false; + + return true; + } + + private static bool IsServiceProviderParameter(ITypeSymbol typeSymbol, INamedTypeSymbol? serviceProviderSymbol) + { + if (serviceProviderSymbol is not null) + return SymbolEqualityComparer.Default.Equals(typeSymbol, serviceProviderSymbol); + + return MatchesServiceProvider(typeSymbol); + } + + private static bool HasEndpointConventionBuilderConstraint( + ITypeParameterSymbol builderTypeParameter, + IMethodSymbol methodSymbol, + INamedTypeSymbol? endpointConventionBuilderSymbol + ) + { + var symbolMatches = builderTypeParameter.ConstraintTypes.Any(constraint => + endpointConventionBuilderSymbol is not null + ? SymbolEqualityComparer.Default.Equals(constraint, endpointConventionBuilderSymbol) + : MatchesEndpointConventionBuilder(constraint) + ); + + if (symbolMatches) + return true; + + return methodSymbol.DeclaringSyntaxReferences + .Select(reference => reference.GetSyntax()) + .OfType() + .SelectMany(methodSyntax => methodSyntax.ConstraintClauses) + .Where(clause => string.Equals(clause.Name.Identifier.ValueText, builderTypeParameter.Name, StringComparison.Ordinal)) + .SelectMany(clause => clause.Constraints.OfType()) + .Any(constraint => IsEndpointConventionBuilderIdentifier(constraint.Type)); + } + + private static bool IsEndpointConventionBuilderIdentifier(TypeSyntax typeSyntax) + { + return typeSyntax switch + { + QualifiedNameSyntax qualified => IsEndpointConventionBuilderIdentifier(qualified.Right), + AliasQualifiedNameSyntax alias => IsEndpointConventionBuilderIdentifier(alias.Name), + SimpleNameSyntax simple => string.Equals(simple.Identifier.ValueText, "IEndpointConventionBuilder", StringComparison.Ordinal), + _ => false, + }; + } + + private static bool MatchesEndpointConventionBuilder(ITypeSymbol typeSymbol) + { + if (typeSymbol is not INamedTypeSymbol namedType) + return false; + + if (!string.Equals(namedType.Name, "IEndpointConventionBuilder", StringComparison.Ordinal)) + return false; + + var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? string.Empty; + return string.Equals(containingNamespace, "Microsoft.AspNetCore.Builder", StringComparison.Ordinal); + } + + private static bool MatchesServiceProvider(ITypeSymbol typeSymbol) + { + if (typeSymbol is not INamedTypeSymbol namedType) + return false; + + if (!string.Equals(namedType.Name, "IServiceProvider", StringComparison.Ordinal)) + return false; + + var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? string.Empty; + return string.Equals(containingNamespace, "System", StringComparison.Ordinal); + } + + private static string? GetMapGroupPattern(INamedTypeSymbol classSymbol) + { + foreach (var attribute in classSymbol.GetAttributes()) + { + var attributeClass = attribute.AttributeClass; + if (attributeClass is null) + continue; + + if (EndpointConfigurationFactory.GetGeneratedAttributeKind(attributeClass) != GeneratedAttributeKind.MapGroup) + continue; + + if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is string pattern) + return pattern.Trim(); + } + + return null; + } + + private static string GetMapGroupIdentifier(string className) + { + if (className.StartsWith(GlobalPrefix, StringComparison.Ordinal)) + className = className.Substring(GlobalPrefix.Length); + + var builder = StringBuilderPool.Get(className.Length + 8); + builder.Append('_'); + + foreach (var character in className) + builder.Append(char.IsLetterOrDigit(character) ? character : '_'); + + builder.Append("_Group"); + return StringBuilderPool.ToStringAndReturn(builder); + } } diff --git a/src/GeneratedEndpoints/Common/RequestHandlerComparer.cs b/src/GeneratedEndpoints/Common/RequestHandlerComparer.cs index f6b78eb..4d3d61a 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerComparer.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerComparer.cs @@ -1,6 +1,3 @@ -using System; -using System.Collections.Generic; - namespace GeneratedEndpoints.Common; internal sealed class RequestHandlerComparer : IComparer diff --git a/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs b/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs index 1aaeb2f..45f5601 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs @@ -1,9 +1,7 @@ using System.Collections.Immutable; -using System.Threading; using Microsoft.CodeAnalysis; using static GeneratedEndpoints.Common.AttributeSymbolMatcher; using static GeneratedEndpoints.Common.Constants; -using static GeneratedEndpoints.MinimalApiGenerator; namespace GeneratedEndpoints.Common; @@ -49,7 +47,7 @@ public static EquatableImmutableArray Build(IMethodSymbol methodSymbo var parameterName = parameter.Name; var parameterType = parameter.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var key = typedKey.HasValue ? ConstLiteral(typedKey.Value) : null; + var key = typedKey?.ToConstLiteral(); var bindingPrefix = GetBindingSourceAttribute(source, key, bindingName); methodParameters.Add(new Parameter(parameterName, parameterType, bindingPrefix)); } @@ -126,6 +124,6 @@ private static string FormatBindingAttribute(string attributeName, string? bindi if (bindingName is null) return $"[{attributeName}] "; - return $"[{attributeName}(Name = {StringLiteral(bindingName)})] "; + return $"[{attributeName}(Name = {bindingName.ToStringLiteral()})] "; } } diff --git a/src/GeneratedEndpoints/Common/StringExtensions.cs b/src/GeneratedEndpoints/Common/StringExtensions.cs new file mode 100644 index 0000000..cff861b --- /dev/null +++ b/src/GeneratedEndpoints/Common/StringExtensions.cs @@ -0,0 +1,68 @@ +using System.Globalization; + +namespace GeneratedEndpoints.Common; + +internal static class StringExtensions +{ + public static string ToStringLiteral(this string? value) + { + if (value is null) + return "null"; + + var firstEscapeIndex = -1; + for (var i = 0; i < value.Length; i++) + { + var c = value[i]; + if (c == '\"' || c == '\\' || c == '\n' || c == '\r' || c == '\t' || c == '\0' || char.IsControl(c)) + { + firstEscapeIndex = i; + break; + } + } + + if (firstEscapeIndex < 0) + return string.Concat("\"", value, "\""); + + var sb = StringBuilderPool.Get(value.Length + 2); + sb.Append('"'); + if (firstEscapeIndex > 0) + sb.Append(value, 0, firstEscapeIndex); + + for (var i = firstEscapeIndex; i < value.Length; i++) + { + var c = value[i]; + switch (c) + { + case '\"': + sb.Append("\\\""); + break; + case '\\': + sb.Append("\\\\"); + break; + case '\n': + sb.Append("\\n"); + break; + case '\r': + sb.Append("\\r"); + break; + case '\t': + sb.Append("\\t"); + break; + case '\0': + sb.Append("\\0"); + break; + default: + if (char.IsControl(c)) + sb.Append("\\u") + .Append(((int)c).ToString("x4", CultureInfo.InvariantCulture)); + else + sb.Append(c); + + break; + } + } + + sb.Append('"'); + return StringBuilderPool.ToStringAndReturn(sb); + } +} diff --git a/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs b/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs index 1047514..0b69f45 100644 --- a/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs +++ b/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs @@ -4,79 +4,47 @@ namespace GeneratedEndpoints.Common; internal static class TypeSymbolExtensions { - public static bool IsValueTask(this ITypeSymbol symbol, out INamedTypeSymbol valueTaskSymbol) + public static bool IsAwaitable(this ITypeSymbol symbol) { - if (symbol is INamedTypeSymbol - { - MetadataName: "ValueTask`1", - ContainingNamespace: + return symbol switch + { + INamedTypeSymbol { - Name: "Tasks", + MetadataName: "ValueTask`1", ContainingNamespace: { - Name: "Threading", - ContainingNamespace: - { - Name: "System", - ContainingNamespace.IsGlobalNamespace: true, - }, + Name: "Tasks", + ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, }, - }, - } namedTypeSymbol) - { - valueTaskSymbol = namedTypeSymbol; - return true; - } - - valueTaskSymbol = null!; - return false; - } - - public static bool IsTask(this ITypeSymbol symbol, out INamedTypeSymbol valueTaskSymbol) - { - if (symbol is INamedTypeSymbol - { - MetadataName: "Task`1", - ContainingNamespace: + } + or INamedTypeSymbol { - Name: "Tasks", + MetadataName: "Task`1", ContainingNamespace: { - Name: "Threading", - ContainingNamespace: - { - Name: "System", - ContainingNamespace.IsGlobalNamespace: true, - }, + Name: "Tasks", + ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, }, - }, - } namedTypeSymbol) - { - valueTaskSymbol = namedTypeSymbol; - return true; - } - - valueTaskSymbol = null!; - return false; - } - - public static bool IsLightResults(this ITypeSymbol symbol, out INamedTypeSymbol lightResultsSymbol) - { - if (symbol is INamedTypeSymbol - { - Name: "Result", - ContainingNamespace: + } + or INamedTypeSymbol { - Name: "LightResults", - ContainingNamespace.IsGlobalNamespace: true, - }, - } namedTypeSymbol) - { - lightResultsSymbol = namedTypeSymbol; - return true; - } - - lightResultsSymbol = null!; - return false; + MetadataName: "ValueTask", + ContainingNamespace: + { + Name: "Tasks", + ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, + }, + } + or INamedTypeSymbol + { + MetadataName: "Task", + ContainingNamespace: + { + Name: "Tasks", + ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, + }, + } => true, + _ => false, + }; } } diff --git a/src/GeneratedEndpoints/Common/TypedConstantExtensions.cs b/src/GeneratedEndpoints/Common/TypedConstantExtensions.cs new file mode 100644 index 0000000..1f29648 --- /dev/null +++ b/src/GeneratedEndpoints/Common/TypedConstantExtensions.cs @@ -0,0 +1,82 @@ +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using Microsoft.CodeAnalysis; + +namespace GeneratedEndpoints.Common; + +internal static class TypedConstantExtensions +{ + [SuppressMessage("Globalization", "CA1308: Normalize strings to uppercase", Justification = "C# boolean literals must be lowercase.")] + public static string ToConstLiteral(this TypedConstant tc) + { + if (tc.IsNull) + return "null"; + var v = tc.Value; + var t = tc.Type; + if (t is null) + return "null"; + + if (t.TypeKind != TypeKind.Enum) + return t.SpecialType switch + { + SpecialType.System_String => ((string?)v).ToStringLiteral(), + SpecialType.System_Char => $"'{EscapeChar((char)v!)}'", + SpecialType.System_Boolean => ((bool)v!).ToString() + .ToLowerInvariant(), + SpecialType.System_Double => ((double)v!).ToString("R", CultureInfo.InvariantCulture), + SpecialType.System_Single => ((float)v!).ToString("R", CultureInfo.InvariantCulture) + "f", + SpecialType.System_Decimal => ((decimal)v!).ToString(CultureInfo.InvariantCulture) + "m", + SpecialType.System_SByte => ((sbyte)v!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_Byte => ((byte)v!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_Int16 => ((short)v!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_UInt16 => ((ushort)v!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_Int32 => ((int)v!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_UInt32 => ((uint)v!).ToString(CultureInfo.InvariantCulture) + "u", + SpecialType.System_Int64 => ((long)v!).ToString(CultureInfo.InvariantCulture) + "L", + SpecialType.System_UInt64 => ((ulong)v!).ToString(CultureInfo.InvariantCulture) + "UL", + _ => (v?.ToString()).ToStringLiteral(), + }; + + var field = t.GetMembers() + .OfType() + .FirstOrDefault(f => f.HasConstantValue && Equals(f.ConstantValue, v)); + + if (field is not null) + return $"{t.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}.{field.Name}"; + + var underlying = ((INamedTypeSymbol)t).EnumUnderlyingType!; + var num = IntegralLiteral(v, underlying.SpecialType); + return $"({t.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}){num}"; + } + + private static string EscapeChar(char c) + { + return c switch + { + '\'' => "\\'", + '\\' => "\\\\", + '\n' => "\\n", + '\r' => "\\r", + '\t' => "\\t", + '\0' => "\\0", + _ when char.IsControl(c) => "\\u" + ((int)c).ToString("x4", CultureInfo.InvariantCulture), + _ => c.ToString(), + }; + } + + private static string IntegralLiteral(object? value, SpecialType underlying) + { + return underlying switch + { + SpecialType.System_SByte => ((sbyte)value!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_Byte => ((byte)value!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_Int16 => ((short)value!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_UInt16 => ((ushort)value!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_Int32 => ((int)value!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_UInt32 => ((uint)value!).ToString(CultureInfo.InvariantCulture) + "u", + SpecialType.System_Int64 => ((long)value!).ToString(CultureInfo.InvariantCulture) + "L", + SpecialType.System_UInt64 => ((ulong)value!).ToString(CultureInfo.InvariantCulture) + "UL", + _ => "0", + }; + } +} diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index f63729c..1961ede 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -2,7 +2,6 @@ using System.Collections.Immutable; using System.ComponentModel; using System.Diagnostics.CodeAnalysis; -using System.Globalization; using System.Runtime.CompilerServices; using System.Text; using GeneratedEndpoints.Common; @@ -14,7 +13,7 @@ namespace GeneratedEndpoints; [Generator] -public sealed partial class MinimalApiGenerator : IIncrementalGenerator +public sealed class MinimalApiGenerator : IIncrementalGenerator { private static readonly ConditionalWeakTable CompilationTypeCaches = new(); private static readonly ConditionalWeakTable RequestHandlerClassCache = new(); @@ -25,8 +24,11 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var requestHandlerProviders = ImmutableArray.CreateBuilder>>(HttpAttributeDefinitions.Length); - foreach (var definition in HttpAttributeDefinitions) + // ReSharper disable once ForCanBeConvertedToForeach + // Do not refactor, use for loop to avoid allocations. + for (var index = 0; index < HttpAttributeDefinitions.Length; index++) { + var definition = HttpAttributeDefinitions[index]; var handlers = context.SyntaxProvider .ForAttributeWithMetadataName(definition.FullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) .WhereNotNull() @@ -40,6 +42,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.RegisterSourceOutput(requestHandlers, GenerateSource); } + private static IncrementalValueProvider> CombineRequestHandlers( ImmutableArray>> handlerProviders ) @@ -123,7 +126,7 @@ private static string RemoveAsyncSuffix(string methodName) return methodName; } - private static ( string HttpMethod, string Pattern, string? Name ) GetRequestHandlerAttribute(AttributeData attribute, CancellationToken cancellationToken) + private static (string HttpMethod, string Pattern, string? Name) GetRequestHandlerAttribute(AttributeData attribute, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -134,8 +137,11 @@ private static ( string HttpMethod, string Pattern, string? Name ) GetRequestHan var pattern = (attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : "") ?? ""; string? name = null; - foreach (var namedArg in attribute.NamedArguments) + // ReSharper disable once ForCanBeConvertedToForeach + // Do not refactor, use for loop to avoid allocations. + for (var index = 0; index < attribute.NamedArguments.Length; index++) { + var namedArg = attribute.NamedArguments[index]; switch (namedArg.Key) { case NameAttributeNamedParameter: @@ -163,63 +169,31 @@ private static (string? DisplayName, string? Description) GetDisplayAndDescripti if (AttributeSymbolMatcher.IsAttribute(attributeClass, nameof(DisplayNameAttribute), ComponentModelNamespaceParts)) { - displayName = EndpointConfigurationFactory.NormalizeOptionalString(attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : null); + displayName = EndpointConfigurationFactory.NormalizeOptionalString(attribute.ConstructorArguments.Length > 0 + ? attribute.ConstructorArguments[0].Value as string + : null + ); continue; } if (AttributeSymbolMatcher.IsAttribute(attributeClass, nameof(DescriptionAttribute), ComponentModelNamespaceParts)) - description = EndpointConfigurationFactory.NormalizeOptionalString(attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : null); + description = EndpointConfigurationFactory.NormalizeOptionalString(attribute.ConstructorArguments.Length > 0 + ? attribute.ConstructorArguments[0].Value as string + : null + ); } return (displayName, description); } - - - [SuppressMessage("Major Code Smell", "S3398:Move this method into a class of its own", Justification = "Shared helper for multiple caching paths.")] - internal static string? GetMapGroupPattern(INamedTypeSymbol classSymbol) - { - foreach (var attribute in classSymbol.GetAttributes()) - { - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - if (EndpointConfigurationFactory.GetGeneratedAttributeKind(attributeClass) != GeneratedAttributeKind.MapGroup) - continue; - - if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is string pattern) - return pattern.Trim(); - } - - return null; - } - - [SuppressMessage("Major Code Smell", "S3398:Move this method into a class of its own", Justification = "Shared helper for multiple caching paths.")] - internal static string GetMapGroupIdentifier(string className) - { - if (className.StartsWith(GlobalPrefix, StringComparison.Ordinal)) - className = className.Substring(GlobalPrefix.Length); - - var builder = StringBuilderPool.Get(className.Length + 8); - builder.Append('_'); - - foreach (var character in className) - builder.Append(char.IsLetterOrDigit(character) ? character : '_'); - - builder.Append("_Group"); - return StringBuilderPool.ToStringAndReturn(builder); - } - - private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol methodSymbol, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); var name = methodSymbol.Name; var isStatic = methodSymbol.IsStatic; - var isAwaitable = methodSymbol.ReturnType.IsTask(out _) || methodSymbol.ReturnType.IsValueTask(out _); + var isAwaitable = methodSymbol.ReturnType.IsAwaitable(); var parameters = RequestHandlerParameterHelper.Build(methodSymbol, cancellationToken); var requestHandlerMethod = new RequestHandlerMethod(name, isStatic, isAwaitable, parameters); @@ -238,6 +212,7 @@ private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol method var typeCache = GetCompilationTypeCache(compilation); var cacheEntry = RequestHandlerClassCache.GetValue(classSymbol, static _ => new RequestHandlerClassCacheEntry()); var requestHandlerClass = cacheEntry.GetOrCreate(classSymbol, typeCache, cancellationToken); + return requestHandlerClass; } @@ -246,148 +221,6 @@ private static CompilationTypeCache GetCompilationTypeCache(Compilation compilat return CompilationTypeCaches.GetValue(compilation, static c => new CompilationTypeCache(c)); } - [SuppressMessage("Major Code Smell", "S3398:Move this method into a class of its own", Justification = "Shared helper reused by caching infrastructure.")] - internal static ConfigureMethodDetails GetConfigureMethodDetails( - INamedTypeSymbol classSymbol, - INamedTypeSymbol? endpointConventionBuilderSymbol, - INamedTypeSymbol? serviceProviderSymbol, - CancellationToken cancellationToken - ) - { - cancellationToken.ThrowIfCancellationRequested(); - - var hasConfigureMethod = false; - var acceptsServiceProvider = false; - foreach (var member in classSymbol.GetMembers(ConfigureMethodName)) - { - cancellationToken.ThrowIfCancellationRequested(); - - if (member is not IMethodSymbol methodSymbol) - continue; - - if (IsConfigureMethod(methodSymbol, endpointConventionBuilderSymbol, serviceProviderSymbol, out var methodAcceptsServiceProvider)) - { - hasConfigureMethod = true; - if (methodAcceptsServiceProvider) - { - acceptsServiceProvider = true; - break; - } - } - } - - return new ConfigureMethodDetails(hasConfigureMethod, acceptsServiceProvider); - } - - private static bool IsConfigureMethod( - IMethodSymbol methodSymbol, - INamedTypeSymbol? endpointConventionBuilderSymbol, - INamedTypeSymbol? serviceProviderSymbol, - out bool acceptsServiceProvider - ) - { - acceptsServiceProvider = false; - - if (!methodSymbol.IsStatic) - return false; - - if (methodSymbol.TypeParameters.Length != 1) - return false; - - if (methodSymbol.Parameters.Length is < 1 or > 2) - return false; - - var builderTypeParameter = methodSymbol.TypeParameters[0]; - var builderParameter = methodSymbol.Parameters[0]; - - if (!SymbolEqualityComparer.Default.Equals(builderParameter.Type, builderTypeParameter)) - return false; - - if (methodSymbol.Parameters.Length == 2) - { - var serviceProviderParameter = methodSymbol.Parameters[1]; - if (!IsServiceProviderParameter(serviceProviderParameter.Type, serviceProviderSymbol)) - return false; - - acceptsServiceProvider = true; - } - - if (!methodSymbol.ReturnsVoid) - return false; - - if (!HasEndpointConventionBuilderConstraint(builderTypeParameter, methodSymbol, endpointConventionBuilderSymbol)) - return false; - - return true; - } - - private static bool IsServiceProviderParameter(ITypeSymbol typeSymbol, INamedTypeSymbol? serviceProviderSymbol) - { - if (serviceProviderSymbol is not null) - return SymbolEqualityComparer.Default.Equals(typeSymbol, serviceProviderSymbol); - - return MatchesServiceProvider(typeSymbol); - } - - private static bool HasEndpointConventionBuilderConstraint( - ITypeParameterSymbol builderTypeParameter, - IMethodSymbol methodSymbol, - INamedTypeSymbol? endpointConventionBuilderSymbol - ) - { - var symbolMatches = builderTypeParameter.ConstraintTypes.Any(constraint => - endpointConventionBuilderSymbol is not null - ? SymbolEqualityComparer.Default.Equals(constraint, endpointConventionBuilderSymbol) - : MatchesEndpointConventionBuilder(constraint) - ); - - if (symbolMatches) - return true; - - return methodSymbol.DeclaringSyntaxReferences - .Select(reference => reference.GetSyntax()) - .OfType() - .SelectMany(methodSyntax => methodSyntax.ConstraintClauses) - .Where(clause => string.Equals(clause.Name.Identifier.ValueText, builderTypeParameter.Name, StringComparison.Ordinal)) - .SelectMany(clause => clause.Constraints.OfType()) - .Any(constraint => IsEndpointConventionBuilderIdentifier(constraint.Type)); - } - - private static bool IsEndpointConventionBuilderIdentifier(TypeSyntax typeSyntax) - { - return typeSyntax switch - { - QualifiedNameSyntax qualified => IsEndpointConventionBuilderIdentifier(qualified.Right), - AliasQualifiedNameSyntax alias => IsEndpointConventionBuilderIdentifier(alias.Name), - SimpleNameSyntax simple => string.Equals(simple.Identifier.ValueText, "IEndpointConventionBuilder", StringComparison.Ordinal), - _ => false, - }; - } - - private static bool MatchesEndpointConventionBuilder(ITypeSymbol typeSymbol) - { - if (typeSymbol is not INamedTypeSymbol namedType) - return false; - - if (!string.Equals(namedType.Name, "IEndpointConventionBuilder", StringComparison.Ordinal)) - return false; - - var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? string.Empty; - return string.Equals(containingNamespace, "Microsoft.AspNetCore.Builder", StringComparison.Ordinal); - } - - private static bool MatchesServiceProvider(ITypeSymbol typeSymbol) - { - if (typeSymbol is not INamedTypeSymbol namedType) - return false; - - if (!string.Equals(namedType.Name, "IServiceProvider", StringComparison.Ordinal)) - return false; - - var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? string.Empty; - return string.Equals(containingNamespace, "System", StringComparison.Ordinal); - } - private static void GenerateSource(SourceProductionContext context, ImmutableArray requestHandlers) { context.CancellationToken.ThrowIfCancellationRequested(); @@ -639,7 +472,7 @@ private static void GenerateUseEndpointHandlersClass(SourceProductionContext con source.Append(" var "); source.Append(groupedClass.MapGroupBuilderIdentifier); source.Append(" = builder.MapGroup("); - source.Append(StringLiteral(groupedClass.MapGroupPattern!)); + source.Append(groupedClass.MapGroupPattern!.ToStringLiteral()); source.Append(')'); AppendEndpointConfiguration(source, " ", groupedClass.Configuration, false); source.AppendLine(";"); @@ -730,7 +563,7 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl source.Append(".MapFallback("); if (!string.IsNullOrEmpty(requestHandler.Pattern)) { - source.Append(StringLiteral(requestHandler.Pattern)); + source.Append(requestHandler.Pattern.ToStringLiteral()); source.Append(", "); } } @@ -740,7 +573,7 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl source.Append(".Map"); source.Append(mapMethodSuffix ?? "Methods"); source.Append('('); - source.Append(StringLiteral(requestHandler.Pattern)); + source.Append(requestHandler.Pattern.ToStringLiteral()); source.Append(", "); if (mapMethodSuffix is null) { @@ -822,7 +655,7 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.AppendLine(); source.Append(indent); source.Append(".WithName("); - source.Append(StringLiteral(metadata.Name)); + source.Append(metadata.Name.ToStringLiteral()); source.Append(')'); } @@ -831,7 +664,7 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.AppendLine(); source.Append(indent); source.Append(".WithDisplayName("); - source.Append(StringLiteral(metadata.DisplayName)); + source.Append(metadata.DisplayName.ToStringLiteral()); source.Append(')'); } @@ -840,7 +673,7 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.AppendLine(); source.Append(indent); source.Append(".WithSummary("); - source.Append(StringLiteral(metadata.Summary)); + source.Append(metadata.Summary.ToStringLiteral()); source.Append(')'); } @@ -849,7 +682,7 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.AppendLine(); source.Append(indent); source.Append(".WithDescription("); - source.Append(StringLiteral(metadata.Description)); + source.Append(metadata.Description.ToStringLiteral()); source.Append(')'); } @@ -858,7 +691,7 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.AppendLine(); source.Append(indent); source.Append(".WithGroupName("); - source.Append(StringLiteral(configuration.EndpointGroupName)); + source.Append(configuration.EndpointGroupName.ToStringLiteral()); source.Append(')'); } @@ -898,7 +731,7 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.Append('('); if (accepts.IsOptional) source.Append("isOptional: true, "); - source.Append(StringLiteral(accepts.ContentType)); + source.Append(accepts.ContentType.ToStringLiteral()); AppendAdditionalContentTypes(source, accepts.AdditionalContentTypes); source.Append(')'); } @@ -963,7 +796,7 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind if (!string.IsNullOrEmpty(configuration.CorsPolicyName)) { source.Append(".RequireCors("); - source.Append(StringLiteral(configuration.CorsPolicyName)); + source.Append(configuration.CorsPolicyName.ToStringLiteral()); source.Append(')'); } else @@ -986,7 +819,7 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.AppendLine(); source.Append(indent); source.Append(".RequireRateLimiting("); - source.Append(StringLiteral(configuration.RateLimitingPolicyName)); + source.Append(configuration.RateLimitingPolicyName.ToStringLiteral()); source.Append(')'); } @@ -1032,7 +865,7 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind if (!string.IsNullOrEmpty(configuration.RequestTimeoutPolicyName)) { source.Append(".WithRequestTimeout("); - source.Append(StringLiteral(configuration.RequestTimeoutPolicyName)); + source.Append(configuration.RequestTimeoutPolicyName.ToStringLiteral()); source.Append(')'); } else @@ -1153,127 +986,6 @@ private static StringBuilder GetUseEndpointHandlersStringBuilder(ImmutableArray< return StringBuilderPool.Get((int)Math.Max(baseSize, estimate)); } - [SuppressMessage("Globalization", "CA1308: Normalize strings to uppercase", Justification = "C# boolean literals must be lowercase.")] - internal static string ConstLiteral(TypedConstant tc) - { - if (tc.IsNull) - return "null"; - var v = tc.Value; - var t = tc.Type; - if (t is null) - return "null"; - - if (t.TypeKind != TypeKind.Enum) - return t.SpecialType switch - { - SpecialType.System_String => StringLiteral((string?)v), - SpecialType.System_Char => $"'{EscapeChar((char)v!)}'", - SpecialType.System_Boolean => ((bool)v!).ToString() - .ToLowerInvariant(), - SpecialType.System_Double => ((double)v!).ToString("R", CultureInfo.InvariantCulture), - SpecialType.System_Single => ((float)v!).ToString("R", CultureInfo.InvariantCulture) + "f", - SpecialType.System_Decimal => ((decimal)v!).ToString(CultureInfo.InvariantCulture) + "m", - SpecialType.System_SByte => ((sbyte)v!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_Byte => ((byte)v!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_Int16 => ((short)v!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_UInt16 => ((ushort)v!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_Int32 => ((int)v!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_UInt32 => ((uint)v!).ToString(CultureInfo.InvariantCulture) + "u", - SpecialType.System_Int64 => ((long)v!).ToString(CultureInfo.InvariantCulture) + "L", - SpecialType.System_UInt64 => ((ulong)v!).ToString(CultureInfo.InvariantCulture) + "UL", - _ => StringLiteral(v?.ToString()), - }; - - var field = t.GetMembers() - .OfType() - .FirstOrDefault(f => f.HasConstantValue && Equals(f.ConstantValue, v)); - - if (field is not null) - return $"{t.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}.{field.Name}"; - - var underlying = ((INamedTypeSymbol)t).EnumUnderlyingType!; - var num = IntegralLiteral(v, underlying.SpecialType); - return $"({t.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}){num}"; - } - - private static string IntegralLiteral(object? value, SpecialType underlying) - { - return underlying switch - { - SpecialType.System_SByte => ((sbyte)value!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_Byte => ((byte)value!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_Int16 => ((short)value!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_UInt16 => ((ushort)value!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_Int32 => ((int)value!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_UInt32 => ((uint)value!).ToString(CultureInfo.InvariantCulture) + "u", - SpecialType.System_Int64 => ((long)value!).ToString(CultureInfo.InvariantCulture) + "L", - SpecialType.System_UInt64 => ((ulong)value!).ToString(CultureInfo.InvariantCulture) + "UL", - _ => "0", - }; - } - - internal static string StringLiteral(string? value) - { - if (value is null) - return "null"; - - var firstEscapeIndex = -1; - for (var i = 0; i < value.Length; i++) - { - var c = value[i]; - if (c == '\"' || c == '\\' || c == '\n' || c == '\r' || c == '\t' || c == '\0' || char.IsControl(c)) - { - firstEscapeIndex = i; - break; - } - } - - if (firstEscapeIndex < 0) - return string.Concat("\"", value, "\""); - - var sb = StringBuilderPool.Get(value.Length + 2); - sb.Append('"'); - if (firstEscapeIndex > 0) - sb.Append(value, 0, firstEscapeIndex); - - for (var i = firstEscapeIndex; i < value.Length; i++) - { - var c = value[i]; - switch (c) - { - case '\"': - sb.Append("\\\""); - break; - case '\\': - sb.Append("\\\\"); - break; - case '\n': - sb.Append("\\n"); - break; - case '\r': - sb.Append("\\r"); - break; - case '\t': - sb.Append("\\t"); - break; - case '\0': - sb.Append("\\0"); - break; - default: - if (char.IsControl(c)) - sb.Append("\\u") - .Append(((int)c).ToString("x4", CultureInfo.InvariantCulture)); - else - sb.Append(c); - - break; - } - } - - sb.Append('"'); - return StringBuilderPool.ToStringAndReturn(sb); - } - private static void AppendAdditionalContentTypes(StringBuilder source, EquatableImmutableArray? additionalContentTypes) { if (additionalContentTypes is not { Count: > 0 }) @@ -1282,7 +994,7 @@ private static void AppendAdditionalContentTypes(StringBuilder source, Equatable foreach (var additional in additionalContentTypes.Value) { source.Append(", "); - source.Append(StringLiteral(additional)); + source.Append(additional.ToStringLiteral()); } } @@ -1291,11 +1003,11 @@ private static void AppendCommaSeparatedLiterals(StringBuilder source, Equatable if (values.Count == 0) return; - source.Append(StringLiteral(values[0])); + source.Append(values[0].ToStringLiteral()); for (var i = 1; i < values.Count; i++) { source.Append(", "); - source.Append(StringLiteral(values[i])); + source.Append(values[i].ToStringLiteral()); } } @@ -1305,22 +1017,7 @@ private static void AppendOptionalContentTypes(StringBuilder source, string? con return; source.Append(", "); - source.Append(contentType is { Length: > 0 } ? StringLiteral(contentType) : "null"); + source.Append(contentType is { Length: > 0 } ? contentType.ToStringLiteral() : "null"); AppendAdditionalContentTypes(source, additionalContentTypes); } - - private static string EscapeChar(char c) - { - return c switch - { - '\'' => "\\'", - '\\' => "\\\\", - '\n' => "\\n", - '\r' => "\\r", - '\t' => "\\t", - '\0' => "\\0", - _ when char.IsControl(c) => "\\u" + ((int)c).ToString("x4", CultureInfo.InvariantCulture), - _ => c.ToString(), - }; - } } From c61c97657a8370564e494ad2289d7645e258b413 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 18:33:19 -0500 Subject: [PATCH 03/32] Further refactor. --- .../Common/EndpointConfigurationFactory.cs | 75 +++++++++--- .../Common/RequestHandlerClassCacheEntry.cs | 4 +- .../Common/RequestHandlerMethod.cs | 1 - .../Common/StringExtensions.cs | 5 + src/GeneratedEndpoints/MinimalApiGenerator.cs | 115 +++++------------- .../GetUserEndpoint.cs | 5 +- ...dVariants_MapEndpointHandlers.verified.txt | 4 +- 7 files changed, 100 insertions(+), 109 deletions(-) diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index df02397..ef17ae2 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -1,7 +1,7 @@ using System.Collections.Immutable; using System.Runtime.CompilerServices; using Microsoft.CodeAnalysis; -using static GeneratedEndpoints.Common.AttributeSymbolMatcher; +using System.ComponentModel; using static GeneratedEndpoints.Common.Constants; namespace GeneratedEndpoints.Common; @@ -16,13 +16,16 @@ private sealed class GeneratedAttributeKindCacheEntry(GeneratedAttributeKind kin private static readonly ConditionalWeakTable GeneratedAttributeKindCache = new(); public static EndpointConfiguration Create( - ImmutableArray attributes, + ISymbol methodSymbol, string? name, - string? displayName, - string? description, bool enforceMethodRequireAuthorizationRules ) { + var attributes = methodSymbol.GetAttributes(); + var (displayName, description) = GetDisplayAndDescriptionAttributes(methodSymbol); + + name ??= RemoveAsyncSuffix(methodSymbol.Name); + var state = new EndpointAttributeState(); PopulateAttributeState(attributes, ref state); @@ -67,6 +70,43 @@ bool enforceMethodRequireAuthorizationRules ); } + private static string RemoveAsyncSuffix(string methodName) + { + if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) + return methodName[..^AsyncSuffix.Length]; + + return methodName; + } + + private static (string? DisplayName, string? Description) GetDisplayAndDescriptionAttributes(ISymbol methodSymbol) + { + string? displayName = null; + string? description = null; + + foreach (var attribute in methodSymbol.GetAttributes()) + { + var attributeClass = attribute.AttributeClass; + if (attributeClass is null) + continue; + + if (AttributeSymbolMatcher.IsAttribute(attributeClass, nameof(DisplayNameAttribute), ComponentModelNamespaceParts)) + { + displayName = attribute.ConstructorArguments.Length > 0 + ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() + : null; + + continue; + } + + if (AttributeSymbolMatcher.IsAttribute(attributeClass, nameof(DescriptionAttribute), ComponentModelNamespaceParts)) + description = attribute.ConstructorArguments.Length > 0 + ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() + : null; + } + + return (displayName, description); + } + public static GeneratedAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) { var definition = attributeClass.OriginalDefinition; @@ -78,11 +118,6 @@ public static GeneratedAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol return cacheEntry.Kind; } - public static string? NormalizeOptionalString(string? value) - { - return string.IsNullOrWhiteSpace(value) ? null : value!.Trim(); - } - private static void PopulateAttributeState(ImmutableArray attributes, ref EndpointAttributeState state) { ref var tags = ref state.Tags; @@ -142,7 +177,7 @@ private static void PopulateAttributeState(ImmutableArray attribu policyName = attribute.ConstructorArguments[0].Value as string; policyName ??= GetNamedStringValue(attribute, PolicyNameAttributeNamedParameter); - requestTimeoutPolicyName = NormalizeOptionalString(policyName); + requestTimeoutPolicyName = policyName.NormalizeOptionalString(); continue; } case GeneratedAttributeKind.Order: @@ -159,7 +194,7 @@ private static void PopulateAttributeState(ImmutableArray attribu case GeneratedAttributeKind.Summary: if (attribute.ConstructorArguments.Length > 0) { - var summaryValue = NormalizeOptionalString(attribute.ConstructorArguments[0].Value as string); + var summaryValue = (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString(); if (!string.IsNullOrEmpty(summaryValue)) summary = summaryValue; } @@ -183,7 +218,7 @@ private static void PopulateAttributeState(ImmutableArray attribu case GeneratedAttributeKind.RequireCors: requireCors = true; corsPolicyName = attribute.ConstructorArguments.Length > 0 - ? NormalizeOptionalString(attribute.ConstructorArguments[0].Value as string) + ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() : null; continue; case GeneratedAttributeKind.RequireHost: @@ -200,7 +235,7 @@ private static void PopulateAttributeState(ImmutableArray attribu case GeneratedAttributeKind.RequireRateLimiting: { var policyName = attribute.ConstructorArguments.Length > 0 - ? NormalizeOptionalString(attribute.ConstructorArguments[0].Value as string) + ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() : null; if (!string.IsNullOrEmpty(policyName)) @@ -248,14 +283,14 @@ private static void PopulateAttributeState(ImmutableArray attribu } } - if (IsAttribute(attributeClass, AllowAnonymousAttributeName, AspNetCoreAuthorizationNamespaceParts)) + if (AttributeSymbolMatcher.IsAttribute(attributeClass, AllowAnonymousAttributeName, AspNetCoreAuthorizationNamespaceParts)) { allowAnonymous = true; hasAllowAnonymousAttribute = true; continue; } - if (IsAttribute(attributeClass, "TagsAttribute", AspNetCoreHttpNamespaceParts)) + if (AttributeSymbolMatcher.IsAttribute(attributeClass, "TagsAttribute", AspNetCoreHttpNamespaceParts)) { if (attribute.ConstructorArguments.Length > 0) { @@ -266,7 +301,7 @@ private static void PopulateAttributeState(ImmutableArray attribu continue; } - if (IsAttribute(attributeClass, "ExcludeFromDescriptionAttribute", AspNetCoreRoutingNamespaceParts)) + if (AttributeSymbolMatcher.IsAttribute(attributeClass, "ExcludeFromDescriptionAttribute", AspNetCoreRoutingNamespaceParts)) excludeFromDescription = true; } } @@ -288,7 +323,7 @@ private static void MergeInto(ref EquatableImmutableArray? target, Immut if (value.Value is not string stringValue) continue; - var trimmed = NormalizeOptionalString(stringValue); + var trimmed = stringValue.NormalizeOptionalString(); if (trimmed is not { Length: > 0 }) continue; @@ -315,7 +350,7 @@ internal static EquatableImmutableArray MergeUnion(EquatableImmutableArr foreach (var value in values) { - var normalized = NormalizeOptionalString(value); + var normalized = value.NormalizeOptionalString(); if (normalized is not { Length: > 0 }) continue; @@ -494,7 +529,7 @@ private static bool GetNamedBoolValue(AttributeData attribute, string namedParam foreach (var namedArg in attribute.NamedArguments) { if (namedArg.Key == namedParameter && namedArg.Value.Value is string stringValue) - return NormalizeOptionalString(stringValue); + return stringValue.NormalizeOptionalString(); } return null; @@ -502,7 +537,7 @@ private static bool GetNamedBoolValue(AttributeData attribute, string namedParam private static GeneratedAttributeKind GetGeneratedAttributeKindCore(INamedTypeSymbol definition) { - if (!IsInNamespace(definition.ContainingNamespace, AttributesNamespaceParts)) + if (!AttributeSymbolMatcher.IsInNamespace(definition.ContainingNamespace, AttributesNamespaceParts)) return GeneratedAttributeKind.None; return definition.Name switch diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs index 647cc95..a2cb367 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs @@ -30,7 +30,7 @@ public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, Compilation var mapGroupPattern = GetMapGroupPattern(classSymbol); var mapGroupIdentifier = mapGroupPattern is null ? null : GetMapGroupIdentifier(name); - var classConfiguration = EndpointConfigurationFactory.Create(classSymbol.GetAttributes(), null, null, null, false); + EndpointConfiguration classConfiguration = EndpointConfigurationFactory.Create(classSymbol, null, false); _value = new RequestHandlerClass(name, isStatic, configureMethodDetails.HasConfigureMethod, configureMethodDetails.ConfigureMethodAcceptsServiceProvider, mapGroupPattern, mapGroupIdentifier, classConfiguration @@ -202,7 +202,7 @@ private static bool MatchesServiceProvider(ITypeSymbol typeSymbol) private static string GetMapGroupIdentifier(string className) { if (className.StartsWith(GlobalPrefix, StringComparison.Ordinal)) - className = className.Substring(GlobalPrefix.Length); + className = className[GlobalPrefix.Length..]; var builder = StringBuilderPool.Get(className.Length + 8); builder.Append('_'); diff --git a/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs b/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs index d6efe14..55aeba8 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs @@ -3,6 +3,5 @@ namespace GeneratedEndpoints.Common; internal readonly record struct RequestHandlerMethod( string Name, bool IsStatic, - bool IsAwaitable, EquatableImmutableArray Parameters ); diff --git a/src/GeneratedEndpoints/Common/StringExtensions.cs b/src/GeneratedEndpoints/Common/StringExtensions.cs index cff861b..0c08744 100644 --- a/src/GeneratedEndpoints/Common/StringExtensions.cs +++ b/src/GeneratedEndpoints/Common/StringExtensions.cs @@ -65,4 +65,9 @@ public static string ToStringLiteral(this string? value) sb.Append('"'); return StringBuilderPool.ToStringAndReturn(sb); } + + public static string? NormalizeOptionalString(this string? value) + { + return string.IsNullOrWhiteSpace(value) ? null : value!.Trim(); + } } diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 1961ede..564b99f 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -1,7 +1,5 @@ using System.Buffers; using System.Collections.Immutable; -using System.ComponentModel; -using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; using System.Text; using GeneratedEndpoints.Common; @@ -12,6 +10,10 @@ namespace GeneratedEndpoints; +// ReSharper disable ForCanBeConvertedToForeach +// ReSharper disable LoopCanBeConvertedToQuery +// Do not refactor, use for loop to avoid allocations. + [Generator] public sealed class MinimalApiGenerator : IIncrementalGenerator { @@ -24,8 +26,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var requestHandlerProviders = ImmutableArray.CreateBuilder>>(HttpAttributeDefinitions.Length); - // ReSharper disable once ForCanBeConvertedToForeach - // Do not refactor, use for loop to avoid allocations. for (var index = 0; index < HttpAttributeDefinitions.Length; index++) { var definition = HttpAttributeDefinitions[index]; @@ -65,23 +65,23 @@ private static void RegisterAttributes(IncrementalGeneratorPostInitializationCon foreach (var definition in HttpAttributeDefinitions) context.AddSource(definition.Hint, definition.SourceText); - context.AddSource(RequireAuthorizationAttributeHint, RequireAuthorizationAttributeSourceText); - context.AddSource(RequireCorsAttributeHint, RequireCorsAttributeSourceText); - context.AddSource(RequireRateLimitingAttributeHint, RequireRateLimitingAttributeSourceText); - context.AddSource(RequireHostAttributeHint, RequireHostAttributeSourceText); + context.AddSource(AcceptsAttributeHint, AcceptsAttributeSourceText); context.AddSource(DisableAntiforgeryAttributeHint, DisableAntiforgeryAttributeSourceText); - context.AddSource(ShortCircuitAttributeHint, ShortCircuitAttributeSourceText); context.AddSource(DisableRequestTimeoutAttributeHint, DisableRequestTimeoutAttributeSourceText); context.AddSource(DisableValidationAttributeHint, DisableValidationAttributeSourceText); - context.AddSource(RequestTimeoutAttributeHint, RequestTimeoutAttributeSourceText); - context.AddSource(OrderAttributeHint, OrderAttributeSourceText); - context.AddSource(MapGroupAttributeHint, MapGroupAttributeSourceText); - context.AddSource(SummaryAttributeHint, SummaryAttributeSourceText); - context.AddSource(AcceptsAttributeHint, AcceptsAttributeSourceText); context.AddSource(EndpointFilterAttributeHint, EndpointFilterAttributeSourceText); - context.AddSource(ProducesResponseAttributeHint, ProducesResponseAttributeSourceText); + context.AddSource(MapGroupAttributeHint, MapGroupAttributeSourceText); + context.AddSource(OrderAttributeHint, OrderAttributeSourceText); context.AddSource(ProducesProblemAttributeHint, ProducesProblemAttributeSourceText); + context.AddSource(ProducesResponseAttributeHint, ProducesResponseAttributeSourceText); context.AddSource(ProducesValidationProblemAttributeHint, ProducesValidationProblemAttributeSourceText); + context.AddSource(RequestTimeoutAttributeHint, RequestTimeoutAttributeSourceText); + context.AddSource(RequireAuthorizationAttributeHint, RequireAuthorizationAttributeSourceText); + context.AddSource(RequireCorsAttributeHint, RequireCorsAttributeSourceText); + context.AddSource(RequireHostAttributeHint, RequireHostAttributeSourceText); + context.AddSource(RequireRateLimitingAttributeHint, RequireRateLimitingAttributeSourceText); + context.AddSource(ShortCircuitAttributeHint, ShortCircuitAttributeSourceText); + context.AddSource(SummaryAttributeHint, SummaryAttributeSourceText); } private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToken cancellationToken) @@ -107,25 +107,13 @@ private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToke var (httpMethod, pattern, name) = GetRequestHandlerAttribute(attribute, cancellationToken); - var (displayName, description) = GetDisplayAndDescriptionAttributes(requestHandlerMethodSymbol); - - name ??= RemoveAsyncSuffix(requestHandlerMethod.Name); - - var methodConfiguration = EndpointConfigurationFactory.Create(requestHandlerMethodSymbol.GetAttributes(), name, displayName, description, true); + var methodConfiguration = EndpointConfigurationFactory.Create(requestHandlerMethodSymbol, name, true); var requestHandler = new RequestHandler(requestHandlerClass.Value, requestHandlerMethod, httpMethod, pattern, methodConfiguration); return requestHandler; } - private static string RemoveAsyncSuffix(string methodName) - { - if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) - return methodName[..^AsyncSuffix.Length]; - - return methodName; - } - private static (string HttpMethod, string Pattern, string? Name) GetRequestHandlerAttribute(AttributeData attribute, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -137,8 +125,6 @@ private static (string HttpMethod, string Pattern, string? Name) GetRequestHandl var pattern = (attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : "") ?? ""; string? name = null; - // ReSharper disable once ForCanBeConvertedToForeach - // Do not refactor, use for loop to avoid allocations. for (var index = 0; index < attribute.NamedArguments.Length; index++) { var namedArg = attribute.NamedArguments[index]; @@ -156,47 +142,15 @@ private static (string HttpMethod, string Pattern, string? Name) GetRequestHandl return (httpMethod, pattern, name); } - private static (string? DisplayName, string? Description) GetDisplayAndDescriptionAttributes(IMethodSymbol methodSymbol) - { - string? displayName = null; - string? description = null; - - foreach (var attribute in methodSymbol.GetAttributes()) - { - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - if (AttributeSymbolMatcher.IsAttribute(attributeClass, nameof(DisplayNameAttribute), ComponentModelNamespaceParts)) - { - displayName = EndpointConfigurationFactory.NormalizeOptionalString(attribute.ConstructorArguments.Length > 0 - ? attribute.ConstructorArguments[0].Value as string - : null - ); - - continue; - } - - if (AttributeSymbolMatcher.IsAttribute(attributeClass, nameof(DescriptionAttribute), ComponentModelNamespaceParts)) - description = EndpointConfigurationFactory.NormalizeOptionalString(attribute.ConstructorArguments.Length > 0 - ? attribute.ConstructorArguments[0].Value as string - : null - ); - } - - return (displayName, description); - } - private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol methodSymbol, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); var name = methodSymbol.Name; var isStatic = methodSymbol.IsStatic; - var isAwaitable = methodSymbol.ReturnType.IsAwaitable(); var parameters = RequestHandlerParameterHelper.Build(methodSymbol, cancellationToken); - var requestHandlerMethod = new RequestHandlerMethod(name, isStatic, isAwaitable, parameters); + var requestHandlerMethod = new RequestHandlerMethod(name, isStatic, parameters); return requestHandlerMethod; } @@ -331,7 +285,7 @@ private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestH { var className = requestHandler.Class.Name; if (className.StartsWith(GlobalPrefix, StringComparison.Ordinal)) - className = className.Substring(GlobalPrefix.Length); + className = className[GlobalPrefix.Length..]; if (className.IndexOf('+') >= 0) className = className.Replace('+', '.'); @@ -389,9 +343,6 @@ private static void GenerateAddEndpointHandlersClass(SourceProductionContext con context.AddSource(AddEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); } - [SuppressMessage("Major Code Smell", "S3267:Loops should be simplified by calling the \"Select\" LINQ method", - Justification = "Manual loops avoid repeated allocations in the source generator." - )] private static List GetDistinctNonStaticClassNames(ImmutableArray requestHandlers) { var classNames = new List(); @@ -399,8 +350,9 @@ private static List GetDistinctNonStaticClassNames(ImmutableArray(StringComparer.Ordinal); - foreach (var requestHandler in requestHandlers) + for (var index = 0; index < requestHandlers.Length; index++) { + var requestHandler = requestHandlers[index]; if (requestHandler.Class.IsStatic) continue; @@ -415,8 +367,11 @@ private static List GetDistinctNonStaticClassNames(ImmutableArray nonStaticClassNames) { var estimate = 512L; - foreach (var className in nonStaticClassNames) + for (var index = 0; index < nonStaticClassNames.Count; index++) + { + var className = nonStaticClassNames[index]; estimate += 36 + className.Length; + } estimate += Math.Max(256, nonStaticClassNames.Count * 12); estimate = (long)(estimate * 1.10); @@ -467,8 +422,9 @@ private static void GenerateUseEndpointHandlersClass(SourceProductionContext con var groupedClasses = GetClassesWithMapGroups(requestHandlers); - foreach (var groupedClass in groupedClasses) + for (var index = 0; index < groupedClasses.Count; index++) { + var groupedClass = groupedClasses[index]; source.Append(" var "); source.Append(groupedClass.MapGroupBuilderIdentifier); source.Append(" = builder.MapGroup("); @@ -504,8 +460,9 @@ private static void GenerateUseEndpointHandlersClass(SourceProductionContext con private static bool HasRateLimitedHandlers(ImmutableArray requestHandlers) { - foreach (var handler in requestHandlers) + for (var index = 0; index < requestHandlers.Length; index++) { + var handler = requestHandlers[index]; if (handler.Configuration.RequireRateLimiting) return true; } @@ -513,9 +470,6 @@ private static bool HasRateLimitedHandlers(ImmutableArray reques return false; } - [SuppressMessage("Major Code Smell", "S3267:Loops should be simplified by calling the \"Select\" LINQ method", - Justification = "Manual loops avoid repeated allocations in the source generator." - )] private static List GetClassesWithMapGroups(ImmutableArray requestHandlers) { var groupedClasses = new List(); @@ -523,8 +477,9 @@ private static List GetClassesWithMapGroups(ImmutableArray< return groupedClasses; var seen = new HashSet(StringComparer.Ordinal); - foreach (var handler in requestHandlers) + for (var index = 0; index < requestHandlers.Length; index++) { + var handler = requestHandlers[index]; var handlerClass = handler.Class; if (handlerClass.MapGroupPattern is null) continue; @@ -590,10 +545,7 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl } else { - source.Append("static "); - if (requestHandler.Method.IsAwaitable) - source.Append("async "); - source.Append("([FromServices] "); + source.Append("static ([FromServices] "); source.Append(requestHandler.Class.Name); source.Append(" handler"); foreach (var parameter in requestHandler.Method.Parameters) @@ -604,10 +556,7 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl source.Append(' '); source.Append(parameter.Name); } - source.Append(") => "); - if (requestHandler.Method.IsAwaitable) - source.Append("await "); - source.Append("handler."); + source.Append(") => handler."); source.Append(requestHandler.Method.Name); source.Append('('); for (var index = 0; index < requestHandler.Method.Parameters.Count; index++) diff --git a/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs b/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs index 4452d06..fb6ff08 100644 --- a/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs +++ b/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.ComponentModel; +using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Generated.Attributes; @@ -31,11 +32,13 @@ internal sealed class GetUserEndpoint(IServiceProvider serviceProvider) [Description("Gets a user by ID when the ID is greater than zero.")] [Summary("Gets a user by ID.")] [MapGet("/users/{id:int}", Name = nameof(GetUser))] - public Results, NotFound, ValidationProblem, ProblemHttpResult> GetUser( + public async ValueTask, NotFound, ValidationProblem, ProblemHttpResult>> GetUser( [FromQuery] int id, [FromKeyedServices(ServiceLifetime.Scoped)] IServiceCollection services ) { + await Task.Yield(); + if (id <= 0) { var errors = new Dictionary diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.AsyncMethodVariants_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/IndividualTests.AsyncMethodVariants_MapEndpointHandlers.verified.txt index 2d957e5..93fa041 100644 --- a/tests/GeneratedEndpoints.Tests/IndividualTests.AsyncMethodVariants_MapEndpointHandlers.verified.txt +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.AsyncMethodVariants_MapEndpointHandlers.verified.txt @@ -25,13 +25,13 @@ internal static class EndpointRouteBuilderExtensions builder.MapGet("/task", static ([FromServices] global::GeneratedEndpointsTests.AsyncHandlerEndpoints handler) => handler.TaskOnly()) .WithName("TaskOnly"); - builder.MapGet("/task-result", static async ([FromServices] global::GeneratedEndpointsTests.AsyncHandlerEndpoints handler, int id) => await handler.TaskWithResult(id)) + builder.MapGet("/task-result", static ([FromServices] global::GeneratedEndpointsTests.AsyncHandlerEndpoints handler, int id) => handler.TaskWithResult(id)) .WithName("TaskWithResult"); builder.MapPost("/valuetask", static ([FromServices] global::GeneratedEndpointsTests.AsyncHandlerEndpoints handler) => handler.ValueTaskOnly()) .WithName("ValueTaskOnly"); - builder.MapPost("/valuetask-result", static async ([FromServices] global::GeneratedEndpointsTests.AsyncHandlerEndpoints handler, int id) => await handler.ValueTaskWithResult(id)) + builder.MapPost("/valuetask-result", static ([FromServices] global::GeneratedEndpointsTests.AsyncHandlerEndpoints handler, int id) => handler.ValueTaskWithResult(id)) .WithName("ValueTaskWithResult"); return builder; From 72548fbc23c0cc010cf1c3d3e9928addcbc01ea4 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 20:01:16 -0500 Subject: [PATCH 04/32] More refactoring. --- .../Common/AttributeDataExtensions.cs | 40 ++++++ .../Common/EndpointConfigurationFactory.cs | 61 ++------- .../Common/RequestHandlerClassCacheEntry.cs | 4 +- .../Common/RequestHandlerParameterHelper.cs | 128 +++++++++--------- .../GetUserEndpoint.cs | 2 +- .../Common/SourceFactory.cs | 14 +- 6 files changed, 123 insertions(+), 126 deletions(-) create mode 100644 src/GeneratedEndpoints/Common/AttributeDataExtensions.cs diff --git a/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs new file mode 100644 index 0000000..d3ddc4f --- /dev/null +++ b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs @@ -0,0 +1,40 @@ +using Microsoft.CodeAnalysis; + +namespace GeneratedEndpoints.Common; + +internal static class AttributeDataExtensions +{ + public static ITypeSymbol? GetNamedTypeSymbol(this AttributeData attribute, string namedParameter) + { + foreach (var namedArg in attribute.NamedArguments) + { + if (namedArg.Key == namedParameter && namedArg.Value.Value is ITypeSymbol typeSymbol) + return typeSymbol; + } + + return null; + } + + public static bool GetNamedBoolValue(this AttributeData attribute, string namedParameter, bool defaultValue = false) + { + foreach (var namedArg in attribute.NamedArguments) + { + if (namedArg.Key == namedParameter && namedArg.Value.Value is bool boolValue) + return boolValue; + } + + return defaultValue; + } + + public static string? GetNamedStringValue(this AttributeData attribute, string namedParameter) + { + foreach (var namedArg in attribute.NamedArguments) + { + if (namedArg.Key == namedParameter && namedArg.Value.Value is string stringValue) + return stringValue.NormalizeOptionalString(); + } + + return null; + } + +} diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index ef17ae2..1f0361b 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -171,13 +171,10 @@ private static void PopulateAttributeState(ImmutableArray attribu { disableRequestTimeout = false; withRequestTimeout = true; - string? policyName = null; if (attribute.ConstructorArguments.Length > 0) - policyName = attribute.ConstructorArguments[0].Value as string; - - policyName ??= GetNamedStringValue(attribute, PolicyNameAttributeNamedParameter); - requestTimeoutPolicyName = policyName.NormalizeOptionalString(); + policyName = (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString(); + requestTimeoutPolicyName = policyName; continue; } case GeneratedAttributeKind.Order: @@ -186,7 +183,7 @@ private static void PopulateAttributeState(ImmutableArray attribu continue; case GeneratedAttributeKind.MapGroup: { - var groupName = GetNamedStringValue(attribute, NameAttributeNamedParameter); + var groupName = attribute.GetNamedStringValue(NameAttributeNamedParameter); if (!string.IsNullOrEmpty(groupName)) endpointGroupName = groupName; continue; @@ -258,7 +255,7 @@ private static void PopulateAttributeState(ImmutableArray attribu ? producesProblemStatusCode : 500; var contentType = attribute.ConstructorArguments.Length > 1 - ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) + ? (attribute.ConstructorArguments[1].Value as string).NormalizeOptionalString() : null; var additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; @@ -273,7 +270,7 @@ private static void PopulateAttributeState(ImmutableArray attribu ? producesValidationProblemStatusCode : 400; var contentType = attribute.ConstructorArguments.Length > 1 - ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) + ? (attribute.ConstructorArguments[1].Value as string).NormalizeOptionalString() : null; var additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; @@ -375,11 +372,6 @@ private static string NormalizeRequiredContentType(string? contentType, string d return string.IsNullOrWhiteSpace(contentType) ? defaultValue : contentType!.Trim(); } - private static string? NormalizeOptionalContentType(string? contentType) - { - return string.IsNullOrWhiteSpace(contentType) ? null : contentType!.Trim(); - } - private static EquatableImmutableArray? GetStringArrayValues(TypedConstant typedConstant) { if (typedConstant.Kind != TypedConstantKind.Array || typedConstant.Values.IsDefaultOrEmpty) @@ -400,7 +392,7 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym string? requestType; string contentType; EquatableImmutableArray? additionalContentTypes; - var isOptional = GetNamedBoolValue(attribute, IsOptionalAttributeNamedParameter); + var isOptional = attribute.GetNamedBoolValue(IsOptionalAttributeNamedParameter); if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) { @@ -411,7 +403,7 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym : "application/json"; additionalContentTypes = attribute.ConstructorArguments.Length > 1 ? GetStringArrayValues(attribute.ConstructorArguments[1]) : null; } - else if (GetNamedTypeSymbol(attribute, RequestTypeAttributeNamedParameter) is { } requestTypeSymbol) + else if (attribute.GetNamedTypeSymbol(RequestTypeAttributeNamedParameter) is { } requestTypeSymbol) { requestType = requestTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); contentType = attribute.ConstructorArguments.Length > 0 @@ -442,16 +434,16 @@ private static void TryAddProducesMetadata(AttributeData attribute, INamedTypeSy statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode ? producesStatusCode : 200; - contentType = attribute.ConstructorArguments.Length > 1 ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) : null; + contentType = attribute.ConstructorArguments.Length > 1 ? (attribute.ConstructorArguments[1].Value as string).NormalizeOptionalString() : null; additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; } - else if (GetNamedTypeSymbol(attribute, ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol) + else if (attribute.GetNamedTypeSymbol(ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol) { responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode ? producesStatusCode : 200; - contentType = attribute.ConstructorArguments.Length > 1 ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) : null; + contentType = attribute.ConstructorArguments.Length > 1 ? (attribute.ConstructorArguments[1].Value as string).NormalizeOptionalString() : null; additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; } else @@ -502,39 +494,6 @@ private static void TryAddEndpointFilterType( endpointFilters.Add(displayString); } - private static ITypeSymbol? GetNamedTypeSymbol(AttributeData attribute, string namedParameter) - { - foreach (var namedArg in attribute.NamedArguments) - { - if (namedArg.Key == namedParameter && namedArg.Value.Value is ITypeSymbol typeSymbol) - return typeSymbol; - } - - return null; - } - - private static bool GetNamedBoolValue(AttributeData attribute, string namedParameter, bool defaultValue = false) - { - foreach (var namedArg in attribute.NamedArguments) - { - if (namedArg.Key == namedParameter && namedArg.Value.Value is bool boolValue) - return boolValue; - } - - return defaultValue; - } - - private static string? GetNamedStringValue(AttributeData attribute, string namedParameter) - { - foreach (var namedArg in attribute.NamedArguments) - { - if (namedArg.Key == namedParameter && namedArg.Value.Value is string stringValue) - return stringValue.NormalizeOptionalString(); - } - - return null; - } - private static GeneratedAttributeKind GetGeneratedAttributeKindCore(INamedTypeSymbol definition) { if (!AttributeSymbolMatcher.IsInNamespace(definition.ContainingNamespace, AttributesNamespaceParts)) diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs index a2cb367..d79a332 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs @@ -165,7 +165,7 @@ private static bool MatchesEndpointConventionBuilder(ITypeSymbol typeSymbol) if (!string.Equals(namedType.Name, "IEndpointConventionBuilder", StringComparison.Ordinal)) return false; - var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? string.Empty; + var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? ""; return string.Equals(containingNamespace, "Microsoft.AspNetCore.Builder", StringComparison.Ordinal); } @@ -177,7 +177,7 @@ private static bool MatchesServiceProvider(ITypeSymbol typeSymbol) if (!string.Equals(namedType.Name, "IServiceProvider", StringComparison.Ordinal)) return false; - var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? string.Empty; + var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? ""; return string.Equals(containingNamespace, "System", StringComparison.Ordinal); } diff --git a/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs b/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs index 45f5601..7a4643d 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs @@ -5,6 +5,10 @@ namespace GeneratedEndpoints.Common; +// ReSharper disable ForCanBeConvertedToForeach +// ReSharper disable LoopCanBeConvertedToQuery +// Do not refactor, use for loop to avoid allocations. + internal static class RequestHandlerParameterHelper { public static EquatableImmutableArray Build(IMethodSymbol methodSymbol, CancellationToken cancellationToken) @@ -12,74 +16,57 @@ public static EquatableImmutableArray Build(IMethodSymbol methodSymbo cancellationToken.ThrowIfCancellationRequested(); var methodParameters = ImmutableArray.CreateBuilder(methodSymbol.Parameters.Length); - foreach (var parameter in methodSymbol.Parameters) + + for (var index = 0; index < methodSymbol.Parameters.Length; index++) { cancellationToken.ThrowIfCancellationRequested(); - var source = BindingSource.None; - TypedConstant? typedKey = null; - string? bindingName = null; + var parameterSymbol = methodSymbol.Parameters[index]; + var parameterName = parameterSymbol.Name; + var parameterType = parameterSymbol.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var bindingPrefix = GetBindingPrefix(parameterSymbol); + var parameter = new Parameter(parameterName, parameterType, bindingPrefix); - foreach (var attribute in parameter.GetAttributes()) - { - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - var attributeSource = GetBindingSourceFromAttributeClass(attributeClass); - if (attributeSource == BindingSource.None) - continue; - - source = attributeSource; - switch (attributeSource) - { - case BindingSource.FromRoute: - case BindingSource.FromQuery: - case BindingSource.FromHeader: - case BindingSource.FromForm: - bindingName = GetBindingAttributeName(attribute) ?? bindingName; - break; - case BindingSource.FromKeyedServices: - typedKey = attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0] : null; - break; - } - } - - var parameterName = parameter.Name; - var parameterType = parameter.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var key = typedKey?.ToConstLiteral(); - var bindingPrefix = GetBindingSourceAttribute(source, key, bindingName); - methodParameters.Add(new Parameter(parameterName, parameterType, bindingPrefix)); + methodParameters.Add(parameter); } return methodParameters.ToEquatableImmutable(); } - private static string? GetBindingAttributeName(AttributeData attribute) + private static string GetBindingPrefix(IParameterSymbol parameter) { - foreach (var namedArg in attribute.NamedArguments) + var source = BindingSource.None; + TypedConstant? typedKey = null; + string? bindingName = null; + + foreach (var attribute in parameter.GetAttributes()) { - if (string.Equals(namedArg.Key, NameAttributeNamedParameter, StringComparison.Ordinal) && namedArg.Value.Value is string namedValue) + var attributeClass = attribute.AttributeClass; + if (attributeClass is null) + continue; + + var attributeSource = GetBindingSourceFromAttributeClass(attributeClass); + if (attributeSource == BindingSource.None) + continue; + + source = attributeSource; + switch (attributeSource) { - var normalized = NormalizeBindingName(namedValue); - if (normalized is not null) - return normalized; + case BindingSource.FromRoute: + case BindingSource.FromQuery: + case BindingSource.FromHeader: + case BindingSource.FromForm: + bindingName = attribute.GetNamedStringValue(NameAttributeNamedParameter); + break; + case BindingSource.FromKeyedServices: + typedKey = attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0] : null; + break; } } - if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is string constructorName) - return NormalizeBindingName(constructorName); + var bindingPrefix = GetBindingSourceAttribute(source, typedKey, bindingName); - return null; - } - - private static string? NormalizeBindingName(string? value) - { - if (string.IsNullOrWhiteSpace(value)) - return null; - - var trimmed = value!.Trim(); - return trimmed.Length > 0 ? trimmed : null; + return bindingPrefix; } private static BindingSource GetBindingSourceFromAttributeClass(INamedTypeSymbol attributeClass) @@ -102,21 +89,32 @@ private static BindingSource GetBindingSourceFromAttributeClass(INamedTypeSymbol }; } - private static string GetBindingSourceAttribute(BindingSource source, string? key, string? bindingName) + private static string GetBindingSourceAttribute(BindingSource source, TypedConstant? typedKey, string? bindingName) { - return source switch + switch (source) { - BindingSource.None => string.Empty, - BindingSource.FromRoute => FormatBindingAttribute("FromRoute", bindingName), - BindingSource.FromQuery => FormatBindingAttribute("FromQuery", bindingName), - BindingSource.FromHeader => FormatBindingAttribute("FromHeader", bindingName), - BindingSource.FromBody => FormatBindingAttribute("FromBody", bindingName), - BindingSource.FromForm => FormatBindingAttribute("FromForm", bindingName), - BindingSource.FromServices => "[FromServices] ", - BindingSource.FromKeyedServices => $"[FromKeyedServices({key})] ", - BindingSource.AsParameters => "[AsParameters] ", - _ => string.Empty, - }; + case BindingSource.None: + return ""; + case BindingSource.FromRoute: + return FormatBindingAttribute("FromRoute", bindingName); + case BindingSource.FromQuery: + return FormatBindingAttribute("FromQuery", bindingName); + case BindingSource.FromHeader: + return FormatBindingAttribute("FromHeader", bindingName); + case BindingSource.FromBody: + return FormatBindingAttribute("FromBody", bindingName); + case BindingSource.FromForm: + return FormatBindingAttribute("FromForm", bindingName); + case BindingSource.FromServices: + return "[FromServices] "; + case BindingSource.FromKeyedServices: + var key = typedKey?.ToConstLiteral(); + return $"[FromKeyedServices({key})] "; + case BindingSource.AsParameters: + return "[AsParameters] "; + default: + return ""; + } } private static string FormatBindingAttribute(string attributeName, string? bindingName) diff --git a/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs b/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs index fb6ff08..31fd2cc 100644 --- a/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs +++ b/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs @@ -33,7 +33,7 @@ internal sealed class GetUserEndpoint(IServiceProvider serviceProvider) [Summary("Gets a user by ID.")] [MapGet("/users/{id:int}", Name = nameof(GetUser))] public async ValueTask, NotFound, ValidationProblem, ProblemHttpResult>> GetUser( - [FromQuery] int id, + [FromHeader(Name = "4")] int id, [FromKeyedServices(ServiceLifetime.Scoped)] IServiceCollection services ) { diff --git a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs index ac721c9..2e25ff4 100644 --- a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs +++ b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs @@ -85,7 +85,7 @@ public static string BuildAuthorizationMatrixSource( if (!string.IsNullOrWhiteSpace(groupName) && mapGroupPattern is null) { - mapGroupPattern = string.Empty; + mapGroupPattern = ""; } if (mapGroupPattern is not null) @@ -114,7 +114,7 @@ public static string BuildAuthorizationMatrixSource( if (applyRequestTimeout) { var timeoutArgument = string.IsNullOrWhiteSpace(requestTimeoutPolicy) - ? string.Empty + ? "" : $"(\"{requestTimeoutPolicy}\")"; builder.AppendLine($"[RequestTimeout{timeoutArgument}]"); } @@ -160,13 +160,13 @@ public static string BuildAuthorizationMatrixSource( if (methodRequireCors) { - var methodCors = string.IsNullOrWhiteSpace(methodCorsPolicy) ? string.Empty : $"(\"{methodCorsPolicy}\")"; + var methodCors = string.IsNullOrWhiteSpace(methodCorsPolicy) ? "" : $"(\"{methodCorsPolicy}\")"; builder.AppendLine($" [RequireCors{methodCors}]"); } if (requireRateLimiting) { - var rateLimit = string.IsNullOrWhiteSpace(rateLimitingPolicy) ? string.Empty : $"(\"{rateLimitingPolicy}\")"; + var rateLimit = string.IsNullOrWhiteSpace(rateLimitingPolicy) ? "" : $"(\"{rateLimitingPolicy}\")"; builder.AppendLine($" [RequireRateLimiting{rateLimit}]"); } @@ -223,7 +223,7 @@ public static string BuildConfigureAndFiltersSource( builder.AppendLine(" public static Ok Handle() => TypedResults.Ok();"); builder.AppendLine(); - builder.AppendLine(" public static void Configure(TBuilder builder" + (configureWithServiceProvider ? ", IServiceProvider services" : string.Empty) + ")"); + builder.AppendLine(" public static void Configure(TBuilder builder" + (configureWithServiceProvider ? ", IServiceProvider services" : "") + ")"); builder.AppendLine(" where TBuilder : IEndpointConventionBuilder"); builder.AppendLine(" {"); builder.AppendLine(" _ = builder;"); @@ -406,7 +406,7 @@ public static string BuildContractsAndBindingSource( if (includeAccepts) { - var secondContentType = string.IsNullOrWhiteSpace(acceptsContentType2) ? string.Empty : $", \"{acceptsContentType2}\""; + var secondContentType = string.IsNullOrWhiteSpace(acceptsContentType2) ? "" : $", \"{acceptsContentType2}\""; builder.AppendLine($" [Accepts(\"{acceptsContentType1 ?? "application/json"}\"{secondContentType})]"); } @@ -417,7 +417,7 @@ public static string BuildContractsAndBindingSource( if (includeProducesResponse) { - var secondProduces = string.IsNullOrWhiteSpace(producesContentType2) ? string.Empty : $", \"{producesContentType2}\""; + var secondProduces = string.IsNullOrWhiteSpace(producesContentType2) ? "" : $", \"{producesContentType2}\""; builder.AppendLine($" [ProducesResponse(200, \"{producesContentType1 ?? "application/json"}\"{secondProduces}, ResponseType = typeof(ResponseRecord))]"); } From 7ab04731c3d65ffa2847562c4537cf946d1059b1 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 20:15:15 -0500 Subject: [PATCH 05/32] Refactor. --- .../Common/EndpointConfiguration.cs | 11 +- .../Common/EndpointConfigurationFactory.cs | 153 +++++++----------- .../Common/RequestHandlerClassCacheEntry.cs | 2 +- .../Common/RequestHandlerMetadata.cs | 14 -- src/GeneratedEndpoints/MinimalApiGenerator.cs | 89 +++++----- 5 files changed, 111 insertions(+), 158 deletions(-) delete mode 100644 src/GeneratedEndpoints/Common/RequestHandlerMetadata.cs diff --git a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs index 7ffd09a..8289fc6 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs @@ -1,7 +1,16 @@ namespace GeneratedEndpoints.Common; internal readonly record struct EndpointConfiguration( - RequestHandlerMetadata Metadata, + string? Name, + string? DisplayName, + string? Summary, + string? Description, + EquatableImmutableArray? Tags, + EquatableImmutableArray? Accepts, + EquatableImmutableArray? Produces, + EquatableImmutableArray? ProducesProblem, + EquatableImmutableArray? ProducesValidationProblem, + bool ExcludeFromDescription, bool RequireAuthorization, EquatableImmutableArray? AuthorizationPolicies, bool DisableAntiforgery, diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index 1f0361b..23eb87d 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -1,25 +1,16 @@ using System.Collections.Immutable; +using System.ComponentModel; using System.Runtime.CompilerServices; using Microsoft.CodeAnalysis; -using System.ComponentModel; using static GeneratedEndpoints.Common.Constants; namespace GeneratedEndpoints.Common; internal static class EndpointConfigurationFactory { - private sealed class GeneratedAttributeKindCacheEntry(GeneratedAttributeKind kind) - { - public GeneratedAttributeKind Kind { get; } = kind; - } - private static readonly ConditionalWeakTable GeneratedAttributeKindCache = new(); - public static EndpointConfiguration Create( - ISymbol methodSymbol, - string? name, - bool enforceMethodRequireAuthorizationRules - ) + public static EndpointConfiguration Create(ISymbol methodSymbol, string? name, bool enforceMethodRequireAuthorizationRules) { var attributes = methodSymbol.GetAttributes(); var (displayName, description) = GetDisplayAndDescriptionAttributes(methodSymbol); @@ -32,42 +23,56 @@ bool enforceMethodRequireAuthorizationRules if (enforceMethodRequireAuthorizationRules && state is { HasRequireAuthorizationAttribute: true, HasAllowAnonymousAttribute: false }) state.AllowAnonymous = false; - var metadata = new RequestHandlerMetadata( - name, - displayName, - state.Summary, - description, - state.Tags, - ToEquatableOrNull(state.Accepts), - ToEquatableOrNull(state.Produces), - ToEquatableOrNull(state.ProducesProblem), - ToEquatableOrNull(state.ProducesValidationProblem), - state.ExcludeFromDescription ?? false - ); - var withRequestTimeout = state.WithRequestTimeout ?? false; var requestTimeoutPolicyName = withRequestTimeout ? state.RequestTimeoutPolicyName : null; - return new EndpointConfiguration( - metadata, - state.RequireAuthorization ?? false, - state.AuthorizationPolicies, - state.DisableAntiforgery ?? false, - state.AllowAnonymous ?? false, - state.RequireCors ?? false, - state.CorsPolicyName, - state.RequiredHosts, - state.RequireRateLimiting ?? false, - state.RateLimitingPolicyName, - ToEquatableOrNull(state.EndpointFilters), - state.ShortCircuit ?? false, - state.DisableValidation ?? false, - state.DisableRequestTimeout ?? false, - withRequestTimeout, - requestTimeoutPolicyName, - state.Order, - state.EndpointGroupName + return new EndpointConfiguration(name, displayName, state.Summary, description, state.Tags, ToEquatableOrNull(state.Accepts), + ToEquatableOrNull(state.Produces), ToEquatableOrNull(state.ProducesProblem), ToEquatableOrNull(state.ProducesValidationProblem), + state.ExcludeFromDescription ?? false, state.RequireAuthorization ?? false, state.AuthorizationPolicies, state.DisableAntiforgery ?? false, + state.AllowAnonymous ?? false, state.RequireCors ?? false, state.CorsPolicyName, state.RequiredHosts, state.RequireRateLimiting ?? false, + state.RateLimitingPolicyName, ToEquatableOrNull(state.EndpointFilters), state.ShortCircuit ?? false, state.DisableValidation ?? false, + state.DisableRequestTimeout ?? false, withRequestTimeout, requestTimeoutPolicyName, state.Order, state.EndpointGroupName + ); + } + + public static GeneratedAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) + { + var definition = attributeClass.OriginalDefinition; + var cacheEntry = GeneratedAttributeKindCache.GetValue( + definition, static def => new GeneratedAttributeKindCacheEntry(GetGeneratedAttributeKindCore(def)) ); + + return cacheEntry.Kind; + } + + internal static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, IEnumerable values) + { + List? list = null; + HashSet? seen = null; + + if (existing is { Count: > 0 }) + { + var count = existing.Value.Count; + list = new List(count + 4); + list.AddRange(existing.Value); + seen = new HashSet(existing.Value, StringComparer.OrdinalIgnoreCase); + } + + foreach (var value in values) + { + var normalized = value.NormalizeOptionalString(); + if (normalized is not { Length: > 0 }) + continue; + + seen ??= new HashSet(StringComparer.OrdinalIgnoreCase); + if (!seen.Add(normalized)) + continue; + + list ??= []; + list.Add(normalized); + } + + return list?.ToEquatableImmutableArray() ?? EquatableImmutableArray.Empty; } private static string RemoveAsyncSuffix(string methodName) @@ -91,33 +96,18 @@ private static (string? DisplayName, string? Description) GetDisplayAndDescripti if (AttributeSymbolMatcher.IsAttribute(attributeClass, nameof(DisplayNameAttribute), ComponentModelNamespaceParts)) { - displayName = attribute.ConstructorArguments.Length > 0 - ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() - : null; + displayName = attribute.ConstructorArguments.Length > 0 ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() : null; continue; } if (AttributeSymbolMatcher.IsAttribute(attributeClass, nameof(DescriptionAttribute), ComponentModelNamespaceParts)) - description = attribute.ConstructorArguments.Length > 0 - ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() - : null; + description = attribute.ConstructorArguments.Length > 0 ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() : null; } return (displayName, description); } - public static GeneratedAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) - { - var definition = attributeClass.OriginalDefinition; - var cacheEntry = GeneratedAttributeKindCache.GetValue( - definition, - static def => new GeneratedAttributeKindCacheEntry(GetGeneratedAttributeKindCore(def)) - ); - - return cacheEntry.Kind; - } - private static void PopulateAttributeState(ImmutableArray attributes, ref EndpointAttributeState state) { ref var tags = ref state.Tags; @@ -332,36 +322,6 @@ private static void MergeInto(ref EquatableImmutableArray? target, Immut MergeInto(ref target, normalized); } - internal static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, IEnumerable values) - { - List? list = null; - HashSet? seen = null; - - if (existing is { Count: > 0 }) - { - var count = existing.Value.Count; - list = new List(count + 4); - list.AddRange(existing.Value); - seen = new HashSet(existing.Value, StringComparer.OrdinalIgnoreCase); - } - - foreach (var value in values) - { - var normalized = value.NormalizeOptionalString(); - if (normalized is not { Length: > 0 }) - continue; - - seen ??= new HashSet(StringComparer.OrdinalIgnoreCase); - if (!seen.Add(normalized)) - continue; - - list ??= []; - list.Add(normalized); - } - - return list?.ToEquatableImmutableArray() ?? EquatableImmutableArray.Empty; - } - private static EquatableImmutableArray? ToEquatableOrNull(List? values) { return values is { Count: > 0 } ? values.ToEquatableImmutableArray() : null; @@ -459,7 +419,8 @@ private static void TryAddEndpointFilter( AttributeData attribute, INamedTypeSymbol attributeClass, ref List? endpointFilters, - ref HashSet? endpointFilterSet) + ref HashSet? endpointFilterSet + ) { if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) { @@ -474,10 +435,7 @@ private static void TryAddEndpointFilter( TryAddEndpointFilterType(filterTypeSymbol, ref endpointFilters, ref endpointFilterSet); } - private static void TryAddEndpointFilterType( - ITypeSymbol? typeSymbol, - ref List? endpointFilters, - ref HashSet? endpointFilterSet) + private static void TryAddEndpointFilterType(ITypeSymbol? typeSymbol, ref List? endpointFilters, ref HashSet? endpointFilterSet) { if (typeSymbol is null or ITypeParameterSymbol or IErrorTypeSymbol) return; @@ -521,4 +479,9 @@ private static GeneratedAttributeKind GetGeneratedAttributeKindCore(INamedTypeSy _ => GeneratedAttributeKind.None, }; } + + private sealed class GeneratedAttributeKindCacheEntry(GeneratedAttributeKind kind) + { + public GeneratedAttributeKind Kind { get; } = kind; + } } diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs index d79a332..52b8250 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs @@ -30,7 +30,7 @@ public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, Compilation var mapGroupPattern = GetMapGroupPattern(classSymbol); var mapGroupIdentifier = mapGroupPattern is null ? null : GetMapGroupIdentifier(name); - EndpointConfiguration classConfiguration = EndpointConfigurationFactory.Create(classSymbol, null, false); + var classConfiguration = EndpointConfigurationFactory.Create(classSymbol, null, false); _value = new RequestHandlerClass(name, isStatic, configureMethodDetails.HasConfigureMethod, configureMethodDetails.ConfigureMethodAcceptsServiceProvider, mapGroupPattern, mapGroupIdentifier, classConfiguration diff --git a/src/GeneratedEndpoints/Common/RequestHandlerMetadata.cs b/src/GeneratedEndpoints/Common/RequestHandlerMetadata.cs deleted file mode 100644 index 8f76cf8..0000000 --- a/src/GeneratedEndpoints/Common/RequestHandlerMetadata.cs +++ /dev/null @@ -1,14 +0,0 @@ -namespace GeneratedEndpoints.Common; - -internal readonly record struct RequestHandlerMetadata( - string? Name, - string? DisplayName, - string? Summary, - string? Description, - EquatableImmutableArray? Tags, - EquatableImmutableArray? Accepts, - EquatableImmutableArray? Produces, - EquatableImmutableArray? ProducesProblem, - EquatableImmutableArray? ProducesValidationProblem, - bool ExcludeFromDescription -); diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 564b99f..e9be12b 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -42,7 +42,6 @@ public void Initialize(IncrementalGeneratorInitializationContext context) context.RegisterSourceOutput(requestHandlers, GenerateSource); } - private static IncrementalValueProvider> CombineRequestHandlers( ImmutableArray>> handlerProviders ) @@ -206,15 +205,10 @@ private static ImmutableArray EnsureUniqueEndpointNames(Immutabl foreach (var index in collidingHandlers) { var handler = builder[index]; - var configuration = handler.Configuration; - var metadata = configuration.Metadata with + var configuration = handler.Configuration with { Name = GetFullyQualifiedMethodDisplayName(handler), }; - configuration = configuration with - { - Metadata = metadata, - }; builder[index] = handler with { Configuration = configuration, @@ -240,7 +234,7 @@ private static ImmutableArray GetRequestHandlersWithNameCollisions(Immutabl for (var index = 0; index < handlerCount; index++) { var handler = requestHandlers[index]; - var name = handler.Configuration.Metadata.Name; + var name = handler.Configuration.Name; if (string.IsNullOrEmpty(name)) continue; @@ -276,7 +270,7 @@ void MarkCollision(int handlerIndex) return; collisionFlags[handlerIndex] = true; - collidingIndices ??= new List(); + collidingIndices ??= []; collidingIndices.Add(handlerIndex); } } @@ -597,41 +591,39 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl private static void AppendEndpointConfiguration(StringBuilder source, string indent, EndpointConfiguration configuration, bool includeNameAndDisplayName) { - var metadata = configuration.Metadata; - - if (includeNameAndDisplayName && !string.IsNullOrEmpty(metadata.Name)) + if (includeNameAndDisplayName && !string.IsNullOrEmpty(configuration.Name)) { source.AppendLine(); source.Append(indent); source.Append(".WithName("); - source.Append(metadata.Name.ToStringLiteral()); + source.Append(configuration.Name.ToStringLiteral()); source.Append(')'); } - if (includeNameAndDisplayName && !string.IsNullOrEmpty(metadata.DisplayName)) + if (includeNameAndDisplayName && !string.IsNullOrEmpty(configuration.DisplayName)) { source.AppendLine(); source.Append(indent); source.Append(".WithDisplayName("); - source.Append(metadata.DisplayName.ToStringLiteral()); + source.Append(configuration.DisplayName.ToStringLiteral()); source.Append(')'); } - if (!string.IsNullOrEmpty(metadata.Summary)) + if (!string.IsNullOrEmpty(configuration.Summary)) { source.AppendLine(); source.Append(indent); source.Append(".WithSummary("); - source.Append(metadata.Summary.ToStringLiteral()); + source.Append(configuration.Summary.ToStringLiteral()); source.Append(')'); } - if (!string.IsNullOrEmpty(metadata.Description)) + if (!string.IsNullOrEmpty(configuration.Description)) { source.AppendLine(); source.Append(indent); source.Append(".WithDescription("); - source.Append(metadata.Description.ToStringLiteral()); + source.Append(configuration.Description.ToStringLiteral()); source.Append(')'); } @@ -653,24 +645,24 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.Append(')'); } - if (metadata.ExcludeFromDescription) + if (configuration.ExcludeFromDescription) { source.AppendLine(); source.Append(indent); source.Append(".ExcludeFromDescription()"); } - if (metadata.Tags is { Count: > 0 }) + if (configuration.Tags is { Count: > 0 }) { source.AppendLine(); source.Append(indent); source.Append(".WithTags("); - AppendCommaSeparatedLiterals(source, metadata.Tags.Value); + AppendCommaSeparatedLiterals(source, configuration.Tags.Value); source.Append(')'); } - if (metadata.Accepts is { Count: > 0 }) - foreach (var accepts in metadata.Accepts.Value) + if (configuration.Accepts is { Count: > 0 }) + foreach (var accepts in configuration.Accepts.Value) { source.AppendLine(); source.Append(indent); @@ -685,8 +677,8 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.Append(')'); } - if (metadata.Produces is { Count: > 0 }) - foreach (var produces in metadata.Produces.Value) + if (configuration.Produces is { Count: > 0 }) + foreach (var produces in configuration.Produces.Value) { source.AppendLine(); source.Append(indent); @@ -699,8 +691,8 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.Append(')'); } - if (metadata.ProducesProblem is { Count: > 0 }) - foreach (var producesProblem in metadata.ProducesProblem.Value) + if (configuration.ProducesProblem is { Count: > 0 }) + foreach (var producesProblem in configuration.ProducesProblem.Value) { source.AppendLine(); source.Append(indent); @@ -710,8 +702,8 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.Append(')'); } - if (metadata.ProducesValidationProblem is { Count: > 0 }) - foreach (var producesValidationProblem in metadata.ProducesValidationProblem.Value) + if (configuration.ProducesValidationProblem is { Count: > 0 }) + foreach (var producesValidationProblem in configuration.ProducesValidationProblem.Value) { source.AppendLine(); source.Append(indent); @@ -836,7 +828,16 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfiguration classConfiguration, EndpointConfiguration methodConfiguration) { - var metadata = MergeRequestHandlerMetadata(classConfiguration.Metadata, methodConfiguration.Metadata); + var name = methodConfiguration.Name ?? classConfiguration.Name; + var displayName = methodConfiguration.DisplayName ?? classConfiguration.DisplayName; + var summary = methodConfiguration.Summary ?? classConfiguration.Summary; + var description = methodConfiguration.Description ?? classConfiguration.Description; + var tags = MergeDistinctStrings(classConfiguration.Tags, methodConfiguration.Tags); + var accepts = ConcatEquatable(classConfiguration.Accepts, methodConfiguration.Accepts); + var produces = ConcatEquatable(classConfiguration.Produces, methodConfiguration.Produces); + var producesProblem = ConcatEquatable(classConfiguration.ProducesProblem, methodConfiguration.ProducesProblem); + var producesValidationProblem = ConcatEquatable(classConfiguration.ProducesValidationProblem, methodConfiguration.ProducesValidationProblem); + var excludeFromDescription = classConfiguration.ExcludeFromDescription || methodConfiguration.ExcludeFromDescription; var authorizationPolicies = MergeDistinctStrings(classConfiguration.AuthorizationPolicies, methodConfiguration.AuthorizationPolicies); var requiredHosts = MergeDistinctStrings(classConfiguration.RequiredHosts, methodConfiguration.RequiredHosts); var endpointFilterTypes = ConcatEquatable(classConfiguration.EndpointFilterTypes, methodConfiguration.EndpointFilterTypes); @@ -866,20 +867,10 @@ private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfigu var order = methodConfiguration.Order ?? classConfiguration.Order; var endpointGroupName = methodConfiguration.EndpointGroupName ?? classConfiguration.EndpointGroupName; - return new EndpointConfiguration(metadata, requireAuthorization, authorizationPolicies, disableAntiforgery, allowAnonymous, requireCors, corsPolicyName, - requiredHosts, requireRateLimiting, rateLimitingPolicyName, endpointFilterTypes, shortCircuit, disableValidation, disableRequestTimeout, - withRequestTimeout, requestTimeoutPolicyName, order, endpointGroupName - ); - } - - private static RequestHandlerMetadata MergeRequestHandlerMetadata(RequestHandlerMetadata classMetadata, RequestHandlerMetadata methodMetadata) - { - return new RequestHandlerMetadata(methodMetadata.Name ?? classMetadata.Name, methodMetadata.DisplayName ?? classMetadata.DisplayName, - methodMetadata.Summary ?? classMetadata.Summary, methodMetadata.Description ?? classMetadata.Description, - MergeDistinctStrings(classMetadata.Tags, methodMetadata.Tags), ConcatEquatable(classMetadata.Accepts, methodMetadata.Accepts), - ConcatEquatable(classMetadata.Produces, methodMetadata.Produces), ConcatEquatable(classMetadata.ProducesProblem, methodMetadata.ProducesProblem), - ConcatEquatable(classMetadata.ProducesValidationProblem, methodMetadata.ProducesValidationProblem), - classMetadata.ExcludeFromDescription || methodMetadata.ExcludeFromDescription + return new EndpointConfiguration(name, displayName, summary, description, tags, accepts, produces, producesProblem, producesValidationProblem, + excludeFromDescription, requireAuthorization, authorizationPolicies, disableAntiforgery, allowAnonymous, requireCors, corsPolicyName, requiredHosts, + requireRateLimiting, rateLimitingPolicyName, endpointFilterTypes, shortCircuit, disableValidation, disableRequestTimeout, withRequestTimeout, + requestTimeoutPolicyName, order, endpointGroupName ); } @@ -952,11 +943,15 @@ private static void AppendCommaSeparatedLiterals(StringBuilder source, Equatable if (values.Count == 0) return; - source.Append(values[0].ToStringLiteral()); + source.Append(values[0] + .ToStringLiteral() + ); for (var i = 1; i < values.Count; i++) { source.Append(", "); - source.Append(values[i].ToStringLiteral()); + source.Append(values[i] + .ToStringLiteral() + ); } } From de7d86a1aa4c8d21bdf75e8e483ec9bd7b6fd64d Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 20:36:14 -0500 Subject: [PATCH 06/32] Refactor. --- src/GeneratedEndpoints/Common/Constants.cs | 6 + .../Common/EndpointAttributeState.cs | 32 -- .../Common/EndpointConfigurationFactory.cs | 318 ++++++++---------- .../Common/NamedTypeSymbolExtensions.cs | 48 ++- ...Kind.cs => RequestHandlerAttributeKind.cs} | 7 +- .../Common/RequestHandlerClassCacheEntry.cs | 2 +- 6 files changed, 186 insertions(+), 227 deletions(-) delete mode 100644 src/GeneratedEndpoints/Common/EndpointAttributeState.cs rename src/GeneratedEndpoints/Common/{GeneratedAttributeKind.cs => RequestHandlerAttributeKind.cs} (74%) diff --git a/src/GeneratedEndpoints/Common/Constants.cs b/src/GeneratedEndpoints/Common/Constants.cs index a75c1c2..351f650 100644 --- a/src/GeneratedEndpoints/Common/Constants.cs +++ b/src/GeneratedEndpoints/Common/Constants.cs @@ -1,3 +1,5 @@ +using System.ComponentModel; + namespace GeneratedEndpoints.Common; internal static partial class Constants @@ -61,7 +63,11 @@ internal static partial class Constants internal const string SummaryAttributeFullyQualifiedName = $"{AttributesNamespace}.{SummaryAttributeName}"; internal const string SummaryAttributeHint = $"{SummaryAttributeFullyQualifiedName}.gs.cs"; + internal const string DisplayNameAttributeName = nameof(DisplayNameAttribute); + internal const string DescriptionAttributeName = nameof(DescriptionAttribute); internal const string AllowAnonymousAttributeName = "AllowAnonymousAttribute"; + internal const string TagsAttributeName = "TagsAttribute"; + internal const string ExcludeFromDescriptionAttributeName = "ExcludeFromDescriptionAttribute"; internal const string EndpointFilterAttributeName = "EndpointFilterAttribute"; internal const string EndpointFilterAttributeFullyQualifiedName = $"{AttributesNamespace}.{EndpointFilterAttributeName}"; diff --git a/src/GeneratedEndpoints/Common/EndpointAttributeState.cs b/src/GeneratedEndpoints/Common/EndpointAttributeState.cs deleted file mode 100644 index 3ac4ed1..0000000 --- a/src/GeneratedEndpoints/Common/EndpointAttributeState.cs +++ /dev/null @@ -1,32 +0,0 @@ -namespace GeneratedEndpoints.Common; - -internal struct EndpointAttributeState -{ - public EquatableImmutableArray? Tags; - public bool? RequireAuthorization; - public EquatableImmutableArray? AuthorizationPolicies; - public bool? DisableAntiforgery; - public bool? AllowAnonymous; - public bool? ExcludeFromDescription; - public List? Accepts; - public List? Produces; - public List? ProducesProblem; - public List? ProducesValidationProblem; - public bool? RequireCors; - public string? CorsPolicyName; - public EquatableImmutableArray? RequiredHosts; - public bool? RequireRateLimiting; - public string? RateLimitingPolicyName; - public List? EndpointFilters; - public HashSet? EndpointFilterSet; - public bool HasAllowAnonymousAttribute; - public bool HasRequireAuthorizationAttribute; - public bool? ShortCircuit; - public bool? DisableValidation; - public bool? DisableRequestTimeout; - public bool? WithRequestTimeout; - public string? RequestTimeoutPolicyName; - public int? Order; - public string? EndpointGroupName; - public string? Summary; -} diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index 23eb87d..a18a16a 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -10,133 +10,42 @@ internal static class EndpointConfigurationFactory { private static readonly ConditionalWeakTable GeneratedAttributeKindCache = new(); - public static EndpointConfiguration Create(ISymbol methodSymbol, string? name, bool enforceMethodRequireAuthorizationRules) + public static EndpointConfiguration Create(ISymbol symbol, string? name, bool enforceMethodRequireAuthorizationRules) { - var attributes = methodSymbol.GetAttributes(); - var (displayName, description) = GetDisplayAndDescriptionAttributes(methodSymbol); + var attributes = symbol.GetAttributes(); - name ??= RemoveAsyncSuffix(methodSymbol.Name); + if (symbol is IMethodSymbol) + name ??= RemoveAsyncSuffix(symbol.Name); - var state = new EndpointAttributeState(); - PopulateAttributeState(attributes, ref state); - - if (enforceMethodRequireAuthorizationRules && state is { HasRequireAuthorizationAttribute: true, HasAllowAnonymousAttribute: false }) - state.AllowAnonymous = false; - - var withRequestTimeout = state.WithRequestTimeout ?? false; - var requestTimeoutPolicyName = withRequestTimeout ? state.RequestTimeoutPolicyName : null; - - return new EndpointConfiguration(name, displayName, state.Summary, description, state.Tags, ToEquatableOrNull(state.Accepts), - ToEquatableOrNull(state.Produces), ToEquatableOrNull(state.ProducesProblem), ToEquatableOrNull(state.ProducesValidationProblem), - state.ExcludeFromDescription ?? false, state.RequireAuthorization ?? false, state.AuthorizationPolicies, state.DisableAntiforgery ?? false, - state.AllowAnonymous ?? false, state.RequireCors ?? false, state.CorsPolicyName, state.RequiredHosts, state.RequireRateLimiting ?? false, - state.RateLimitingPolicyName, ToEquatableOrNull(state.EndpointFilters), state.ShortCircuit ?? false, state.DisableValidation ?? false, - state.DisableRequestTimeout ?? false, withRequestTimeout, requestTimeoutPolicyName, state.Order, state.EndpointGroupName - ); - } - - public static GeneratedAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) - { - var definition = attributeClass.OriginalDefinition; - var cacheEntry = GeneratedAttributeKindCache.GetValue( - definition, static def => new GeneratedAttributeKindCacheEntry(GetGeneratedAttributeKindCore(def)) - ); - - return cacheEntry.Kind; - } - - internal static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, IEnumerable values) - { - List? list = null; - HashSet? seen = null; - - if (existing is { Count: > 0 }) - { - var count = existing.Value.Count; - list = new List(count + 4); - list.AddRange(existing.Value); - seen = new HashSet(existing.Value, StringComparer.OrdinalIgnoreCase); - } - - foreach (var value in values) - { - var normalized = value.NormalizeOptionalString(); - if (normalized is not { Length: > 0 }) - continue; - - seen ??= new HashSet(StringComparer.OrdinalIgnoreCase); - if (!seen.Add(normalized)) - continue; - - list ??= []; - list.Add(normalized); - } - - return list?.ToEquatableImmutableArray() ?? EquatableImmutableArray.Empty; - } - - private static string RemoveAsyncSuffix(string methodName) - { - if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) - return methodName[..^AsyncSuffix.Length]; - - return methodName; - } - - private static (string? DisplayName, string? Description) GetDisplayAndDescriptionAttributes(ISymbol methodSymbol) - { string? displayName = null; string? description = null; - - foreach (var attribute in methodSymbol.GetAttributes()) - { - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - if (AttributeSymbolMatcher.IsAttribute(attributeClass, nameof(DisplayNameAttribute), ComponentModelNamespaceParts)) - { - displayName = attribute.ConstructorArguments.Length > 0 ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() : null; - - continue; - } - - if (AttributeSymbolMatcher.IsAttribute(attributeClass, nameof(DescriptionAttribute), ComponentModelNamespaceParts)) - description = attribute.ConstructorArguments.Length > 0 ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() : null; - } - - return (displayName, description); - } - - private static void PopulateAttributeState(ImmutableArray attributes, ref EndpointAttributeState state) - { - ref var tags = ref state.Tags; - ref var requireAuthorization = ref state.RequireAuthorization; - ref var authorizationPolicies = ref state.AuthorizationPolicies; - ref var disableAntiforgery = ref state.DisableAntiforgery; - ref var allowAnonymous = ref state.AllowAnonymous; - ref var excludeFromDescription = ref state.ExcludeFromDescription; - ref var accepts = ref state.Accepts; - ref var produces = ref state.Produces; - ref var producesProblem = ref state.ProducesProblem; - ref var producesValidationProblem = ref state.ProducesValidationProblem; - ref var requireCors = ref state.RequireCors; - ref var corsPolicyName = ref state.CorsPolicyName; - ref var requiredHosts = ref state.RequiredHosts; - ref var requireRateLimiting = ref state.RequireRateLimiting; - ref var rateLimitingPolicyName = ref state.RateLimitingPolicyName; - ref var endpointFilters = ref state.EndpointFilters; - ref var endpointFilterSet = ref state.EndpointFilterSet; - ref var hasAllowAnonymousAttribute = ref state.HasAllowAnonymousAttribute; - ref var hasRequireAuthorizationAttribute = ref state.HasRequireAuthorizationAttribute; - ref var shortCircuit = ref state.ShortCircuit; - ref var disableValidation = ref state.DisableValidation; - ref var disableRequestTimeout = ref state.DisableRequestTimeout; - ref var withRequestTimeout = ref state.WithRequestTimeout; - ref var requestTimeoutPolicyName = ref state.RequestTimeoutPolicyName; - ref var order = ref state.Order; - ref var endpointGroupName = ref state.EndpointGroupName; - ref var summary = ref state.Summary; + EquatableImmutableArray? tags = null; + bool? requireAuthorization = null; + EquatableImmutableArray? authorizationPolicies = null; + bool? disableAntiforgery = null; + bool? allowAnonymous = null; + bool? excludeFromDescription = null; + List? accepts = null; + List? produces = null; + List? producesProblem = null; + List? producesValidationProblem = null; + bool? requireCors = null; + string? corsPolicyName = null; + EquatableImmutableArray? requiredHosts = null; + bool? requireRateLimiting = null; + string? rateLimitingPolicyName = null; + List? endpointFilters = null; + HashSet? endpointFilterSet = null; + bool hasAllowAnonymousAttribute = false; + bool hasRequireAuthorizationAttribute = false; + bool? shortCircuit = null; + bool? disableValidation = null; + bool? disableRequestTimeout = null; + bool? withRequestTimeout = null; + string? requestTimeoutPolicyName = null; + int? order = null; + string? endpointGroupName = null; + string? summary = null; foreach (var attribute in attributes) { @@ -144,20 +53,21 @@ private static void PopulateAttributeState(ImmutableArray attribu if (attributeClass is null) continue; - switch (GetGeneratedAttributeKind(attributeClass)) + var attributeKind = GetGeneratedAttributeKind(attributeClass); + switch (attributeKind) { - case GeneratedAttributeKind.ShortCircuit: + case RequestHandlerAttributeKind.ShortCircuit: shortCircuit = true; continue; - case GeneratedAttributeKind.DisableValidation: + case RequestHandlerAttributeKind.DisableValidation: disableValidation = true; continue; - case GeneratedAttributeKind.DisableRequestTimeout: + case RequestHandlerAttributeKind.DisableRequestTimeout: disableRequestTimeout = true; withRequestTimeout = false; requestTimeoutPolicyName = null; continue; - case GeneratedAttributeKind.RequestTimeout: + case RequestHandlerAttributeKind.RequestTimeout: { disableRequestTimeout = false; withRequestTimeout = true; @@ -167,18 +77,18 @@ private static void PopulateAttributeState(ImmutableArray attribu requestTimeoutPolicyName = policyName; continue; } - case GeneratedAttributeKind.Order: + case RequestHandlerAttributeKind.Order: if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int orderValue) order = orderValue; continue; - case GeneratedAttributeKind.MapGroup: + case RequestHandlerAttributeKind.MapGroup: { var groupName = attribute.GetNamedStringValue(NameAttributeNamedParameter); if (!string.IsNullOrEmpty(groupName)) endpointGroupName = groupName; continue; } - case GeneratedAttributeKind.Summary: + case RequestHandlerAttributeKind.Summary: if (attribute.ConstructorArguments.Length > 0) { var summaryValue = (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString(); @@ -186,13 +96,13 @@ private static void PopulateAttributeState(ImmutableArray attribu summary = summaryValue; } continue; - case GeneratedAttributeKind.Accepts: + case RequestHandlerAttributeKind.Accepts: TryAddAcceptsMetadata(attribute, attributeClass, ref accepts); continue; - case GeneratedAttributeKind.ProducesResponse: + case RequestHandlerAttributeKind.ProducesResponse: TryAddProducesMetadata(attribute, attributeClass, ref produces); continue; - case GeneratedAttributeKind.RequireAuthorization: + case RequestHandlerAttributeKind.RequireAuthorization: requireAuthorization = true; hasRequireAuthorizationAttribute = true; if (attribute.ConstructorArguments.Length == 1) @@ -202,13 +112,13 @@ private static void PopulateAttributeState(ImmutableArray attribu } continue; - case GeneratedAttributeKind.RequireCors: + case RequestHandlerAttributeKind.RequireCors: requireCors = true; corsPolicyName = attribute.ConstructorArguments.Length > 0 ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() : null; continue; - case GeneratedAttributeKind.RequireHost: + case RequestHandlerAttributeKind.RequireHost: if (attribute.ConstructorArguments.Length == 1) { var arg = attribute.ConstructorArguments[0]; @@ -219,7 +129,7 @@ private static void PopulateAttributeState(ImmutableArray attribu } continue; - case GeneratedAttributeKind.RequireRateLimiting: + case RequestHandlerAttributeKind.RequireRateLimiting: { var policyName = attribute.ConstructorArguments.Length > 0 ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() @@ -233,13 +143,13 @@ private static void PopulateAttributeState(ImmutableArray attribu continue; } - case GeneratedAttributeKind.EndpointFilter: + case RequestHandlerAttributeKind.EndpointFilter: TryAddEndpointFilter(attribute, attributeClass, ref endpointFilters, ref endpointFilterSet); continue; - case GeneratedAttributeKind.DisableAntiforgery: + case RequestHandlerAttributeKind.DisableAntiforgery: disableAntiforgery = true; continue; - case GeneratedAttributeKind.ProducesProblem: + case RequestHandlerAttributeKind.ProducesProblem: { var statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesProblemStatusCode ? producesProblemStatusCode @@ -253,7 +163,7 @@ private static void PopulateAttributeState(ImmutableArray attribu producesProblemList.Add(new ProducesProblemMetadata(statusCode, contentType, additionalContentTypes)); continue; } - case GeneratedAttributeKind.ProducesValidationProblem: + case RequestHandlerAttributeKind.ProducesValidationProblem: { var statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesValidationProblemStatusCode @@ -268,29 +178,94 @@ private static void PopulateAttributeState(ImmutableArray attribu producesValidationProblemList.Add(new ProducesValidationProblemMetadata(statusCode, contentType, additionalContentTypes)); continue; } + case RequestHandlerAttributeKind.DisplayName: + displayName = attribute.ConstructorArguments.Length > 0 ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() : null; + break; + case RequestHandlerAttributeKind.Description: + description = attribute.ConstructorArguments.Length > 0 ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() : null; + break; + case RequestHandlerAttributeKind.AllowAnonymous: + allowAnonymous = true; + hasAllowAnonymousAttribute = true; + break; + case RequestHandlerAttributeKind.Tags: + if (attribute.ConstructorArguments.Length > 0) + { + var arg = attribute.ConstructorArguments[0]; + MergeInto(ref tags, arg.Values); + } + break; + case RequestHandlerAttributeKind.ExcludeFromDescription: + excludeFromDescription = true; + break; + case RequestHandlerAttributeKind.None: + default: + break; } + } - if (AttributeSymbolMatcher.IsAttribute(attributeClass, AllowAnonymousAttributeName, AspNetCoreAuthorizationNamespaceParts)) - { - allowAnonymous = true; - hasAllowAnonymousAttribute = true; - continue; - } - if (AttributeSymbolMatcher.IsAttribute(attributeClass, "TagsAttribute", AspNetCoreHttpNamespaceParts)) - { - if (attribute.ConstructorArguments.Length > 0) - { - var arg = attribute.ConstructorArguments[0]; - MergeInto(ref tags, arg.Values); - } + if (enforceMethodRequireAuthorizationRules && hasRequireAuthorizationAttribute && !hasAllowAnonymousAttribute) + allowAnonymous = false; + var withRequestTimeout1 = withRequestTimeout ?? false; + var requestTimeoutPolicyName1 = withRequestTimeout1 ? requestTimeoutPolicyName : null; + + return new EndpointConfiguration(name, displayName, summary, description, tags, ToEquatableOrNull(accepts), + ToEquatableOrNull(produces), ToEquatableOrNull(producesProblem), ToEquatableOrNull(producesValidationProblem), + excludeFromDescription ?? false, requireAuthorization ?? false, authorizationPolicies, disableAntiforgery ?? false, + allowAnonymous ?? false, requireCors ?? false, corsPolicyName, requiredHosts, requireRateLimiting ?? false, + rateLimitingPolicyName, ToEquatableOrNull(endpointFilters), shortCircuit ?? false, disableValidation ?? false, + disableRequestTimeout ?? false, withRequestTimeout1, requestTimeoutPolicyName1, order, endpointGroupName + ); + } + + public static RequestHandlerAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) + { + var definition = attributeClass.OriginalDefinition; + var cacheEntry = GeneratedAttributeKindCache.GetValue( + definition, static def => new GeneratedAttributeKindCacheEntry(def.GetRequestHandlerAttributeKind()) + ); + + return cacheEntry.Kind; + } + + internal static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, IEnumerable values) + { + List? list = null; + HashSet? seen = null; + + if (existing is { Count: > 0 }) + { + var count = existing.Value.Count; + list = new List(count + 4); + list.AddRange(existing.Value); + seen = new HashSet(existing.Value, StringComparer.OrdinalIgnoreCase); + } + + foreach (var value in values) + { + var normalized = value.NormalizeOptionalString(); + if (normalized is not { Length: > 0 }) + continue; + + seen ??= new HashSet(StringComparer.OrdinalIgnoreCase); + if (!seen.Add(normalized)) continue; - } - if (AttributeSymbolMatcher.IsAttribute(attributeClass, "ExcludeFromDescriptionAttribute", AspNetCoreRoutingNamespaceParts)) - excludeFromDescription = true; + list ??= []; + list.Add(normalized); } + + return list?.ToEquatableImmutableArray() ?? EquatableImmutableArray.Empty; + } + + private static string RemoveAsyncSuffix(string methodName) + { + if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) + return methodName[..^AsyncSuffix.Length]; + + return methodName; } private static void MergeInto(ref EquatableImmutableArray? target, IEnumerable values) @@ -452,36 +427,9 @@ private static void TryAddEndpointFilterType(ITypeSymbol? typeSymbol, ref List GeneratedAttributeKind.ShortCircuit, - DisableValidationAttributeName => GeneratedAttributeKind.DisableValidation, - DisableRequestTimeoutAttributeName => GeneratedAttributeKind.DisableRequestTimeout, - RequestTimeoutAttributeName => GeneratedAttributeKind.RequestTimeout, - OrderAttributeName => GeneratedAttributeKind.Order, - MapGroupAttributeName => GeneratedAttributeKind.MapGroup, - SummaryAttributeName => GeneratedAttributeKind.Summary, - AcceptsAttributeName => GeneratedAttributeKind.Accepts, - ProducesResponseAttributeName => GeneratedAttributeKind.ProducesResponse, - RequireAuthorizationAttributeName => GeneratedAttributeKind.RequireAuthorization, - RequireCorsAttributeName => GeneratedAttributeKind.RequireCors, - RequireHostAttributeName => GeneratedAttributeKind.RequireHost, - RequireRateLimitingAttributeName => GeneratedAttributeKind.RequireRateLimiting, - EndpointFilterAttributeName => GeneratedAttributeKind.EndpointFilter, - DisableAntiforgeryAttributeName => GeneratedAttributeKind.DisableAntiforgery, - ProducesProblemAttributeName => GeneratedAttributeKind.ProducesProblem, - ProducesValidationProblemAttributeName => GeneratedAttributeKind.ProducesValidationProblem, - _ => GeneratedAttributeKind.None, - }; - } - private sealed class GeneratedAttributeKindCacheEntry(GeneratedAttributeKind kind) + private sealed class GeneratedAttributeKindCacheEntry(RequestHandlerAttributeKind kind) { - public GeneratedAttributeKind Kind { get; } = kind; + public RequestHandlerAttributeKind Kind { get; } = kind; } } diff --git a/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs b/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs index d2e5370..36c1bca 100644 --- a/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs +++ b/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs @@ -1,19 +1,51 @@ using Microsoft.CodeAnalysis; +using static GeneratedEndpoints.Common.Constants; namespace GeneratedEndpoints.Common; /// Provides extension methods for working with named type symbols. internal static class NamedTypeSymbolExtensions { - public static bool HasTypeArgument(this INamedTypeSymbol namedTypeSymbol, out ITypeSymbol typeParameter) + public static RequestHandlerAttributeKind GetRequestHandlerAttributeKind(this INamedTypeSymbol definition) { - if (namedTypeSymbol.TypeArguments.Length == 1) - { - typeParameter = namedTypeSymbol.TypeArguments[0]; - return true; - } + if (AttributeSymbolMatcher.IsAttribute(definition, DisplayNameAttributeName, ComponentModelNamespaceParts)) + return RequestHandlerAttributeKind.DisplayName; + + if (AttributeSymbolMatcher.IsAttribute(definition, DescriptionAttributeName, ComponentModelNamespaceParts)) + return RequestHandlerAttributeKind.Description; + + if (AttributeSymbolMatcher.IsAttribute(definition, AllowAnonymousAttributeName, AspNetCoreAuthorizationNamespaceParts)) + return RequestHandlerAttributeKind.AllowAnonymous; + + if (AttributeSymbolMatcher.IsAttribute(definition, TagsAttributeName, AspNetCoreHttpNamespaceParts)) + return RequestHandlerAttributeKind.Tags; - typeParameter = null!; - return false; + if (AttributeSymbolMatcher.IsAttribute(definition, ExcludeFromDescriptionAttributeName, AspNetCoreRoutingNamespaceParts)) + return RequestHandlerAttributeKind.ExcludeFromDescription; + + if (!AttributeSymbolMatcher.IsInNamespace(definition.ContainingNamespace, AttributesNamespaceParts)) + return RequestHandlerAttributeKind.None; + + return definition.Name switch + { + ShortCircuitAttributeName => RequestHandlerAttributeKind.ShortCircuit, + DisableValidationAttributeName => RequestHandlerAttributeKind.DisableValidation, + DisableRequestTimeoutAttributeName => RequestHandlerAttributeKind.DisableRequestTimeout, + RequestTimeoutAttributeName => RequestHandlerAttributeKind.RequestTimeout, + OrderAttributeName => RequestHandlerAttributeKind.Order, + MapGroupAttributeName => RequestHandlerAttributeKind.MapGroup, + SummaryAttributeName => RequestHandlerAttributeKind.Summary, + AcceptsAttributeName => RequestHandlerAttributeKind.Accepts, + ProducesResponseAttributeName => RequestHandlerAttributeKind.ProducesResponse, + RequireAuthorizationAttributeName => RequestHandlerAttributeKind.RequireAuthorization, + RequireCorsAttributeName => RequestHandlerAttributeKind.RequireCors, + RequireHostAttributeName => RequestHandlerAttributeKind.RequireHost, + RequireRateLimitingAttributeName => RequestHandlerAttributeKind.RequireRateLimiting, + EndpointFilterAttributeName => RequestHandlerAttributeKind.EndpointFilter, + DisableAntiforgeryAttributeName => RequestHandlerAttributeKind.DisableAntiforgery, + ProducesProblemAttributeName => RequestHandlerAttributeKind.ProducesProblem, + ProducesValidationProblemAttributeName => RequestHandlerAttributeKind.ProducesValidationProblem, + _ => RequestHandlerAttributeKind.None, + }; } } diff --git a/src/GeneratedEndpoints/Common/GeneratedAttributeKind.cs b/src/GeneratedEndpoints/Common/RequestHandlerAttributeKind.cs similarity index 74% rename from src/GeneratedEndpoints/Common/GeneratedAttributeKind.cs rename to src/GeneratedEndpoints/Common/RequestHandlerAttributeKind.cs index b9af832..1c4ff22 100644 --- a/src/GeneratedEndpoints/Common/GeneratedAttributeKind.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerAttributeKind.cs @@ -1,6 +1,6 @@ namespace GeneratedEndpoints.Common; -internal enum GeneratedAttributeKind +internal enum RequestHandlerAttributeKind { None = 0, ShortCircuit, @@ -20,4 +20,9 @@ internal enum GeneratedAttributeKind DisableAntiforgery, ProducesProblem, ProducesValidationProblem, + DisplayName, + Description, + AllowAnonymous, + Tags, + ExcludeFromDescription } diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs index 52b8250..2f16f98 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs @@ -189,7 +189,7 @@ private static bool MatchesServiceProvider(ITypeSymbol typeSymbol) if (attributeClass is null) continue; - if (EndpointConfigurationFactory.GetGeneratedAttributeKind(attributeClass) != GeneratedAttributeKind.MapGroup) + if (EndpointConfigurationFactory.GetGeneratedAttributeKind(attributeClass) != RequestHandlerAttributeKind.MapGroup) continue; if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is string pattern) From 4eff3a8c485044a8f4e275d998a25dbea4cfa66f Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 21:11:21 -0500 Subject: [PATCH 07/32] Refactored. --- .../Common/AttributeDataExtensions.cs | 51 +++++- .../Common/EndpointConfiguration.cs | 59 +++--- .../Common/EndpointConfigurationFactory.cs | 171 ++++++------------ .../Common/RequestHandlerClassCacheEntry.cs | 2 +- src/GeneratedEndpoints/MinimalApiGenerator.cs | 37 +++- 5 files changed, 171 insertions(+), 149 deletions(-) diff --git a/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs index d3ddc4f..8fd6358 100644 --- a/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs +++ b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs @@ -4,6 +4,56 @@ namespace GeneratedEndpoints.Common; internal static class AttributeDataExtensions { + public static string? GetConstructorStringValue(this AttributeData attribute, int position = 0) + { + if (attribute.ConstructorArguments.Length > position) + return (attribute.ConstructorArguments[position].Value as string).NormalizeOptionalString(); + + return null; + } + + public static EquatableImmutableArray? GetConstructorStringArray(this AttributeData attribute, int position = 0) + { + if (attribute.ConstructorArguments.Length <= position) + return null; + + var arg = attribute.ConstructorArguments[position]; + if (arg is { Kind: TypedConstantKind.Array, Values.Length: > 0 }) + { + List? normalized = null; + foreach (var value in arg.Values) + { + if (value.Value is not string stringValue) + continue; + + var trimmed = stringValue.NormalizeOptionalString(); + if (trimmed is not { Length: > 0 }) + continue; + + normalized ??= new List(arg.Values.Length); + normalized.Add(trimmed); + } + + if (normalized is { Count: > 0 }) + return normalized.ToEquatableImmutableArray(); + + return null; + } + + if (arg.Value is string singleHost && !string.IsNullOrWhiteSpace(singleHost)) + return new[] { singleHost.Trim() }.ToEquatableImmutableArray(); + + return null; + } + + public static int? GetConstructorIntValue(this AttributeData attribute, int position = 0) + { + if (attribute.ConstructorArguments.Length > position && attribute.ConstructorArguments[position].Value is int value) + return value; + + return null; + } + public static ITypeSymbol? GetNamedTypeSymbol(this AttributeData attribute, string namedParameter) { foreach (var namedArg in attribute.NamedArguments) @@ -36,5 +86,4 @@ public static bool GetNamedBoolValue(this AttributeData attribute, string namedP return null; } - } diff --git a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs index 8289fc6..6ebf038 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs @@ -1,31 +1,32 @@ namespace GeneratedEndpoints.Common; -internal readonly record struct EndpointConfiguration( - string? Name, - string? DisplayName, - string? Summary, - string? Description, - EquatableImmutableArray? Tags, - EquatableImmutableArray? Accepts, - EquatableImmutableArray? Produces, - EquatableImmutableArray? ProducesProblem, - EquatableImmutableArray? ProducesValidationProblem, - bool ExcludeFromDescription, - bool RequireAuthorization, - EquatableImmutableArray? AuthorizationPolicies, - bool DisableAntiforgery, - bool AllowAnonymous, - bool RequireCors, - string? CorsPolicyName, - EquatableImmutableArray? RequiredHosts, - bool RequireRateLimiting, - string? RateLimitingPolicyName, - EquatableImmutableArray? EndpointFilterTypes, - bool ShortCircuit, - bool DisableValidation, - bool DisableRequestTimeout, - bool WithRequestTimeout, - string? RequestTimeoutPolicyName, - int? Order, - string? EndpointGroupName -); +internal readonly record struct EndpointConfiguration +{ + public string? Name { get; init; } + public string? DisplayName { get; init; } + public string? Summary { get; init; } + public string? Description { get; init; } + public EquatableImmutableArray? Tags { get; init; } + public EquatableImmutableArray? Accepts { get; init; } + public EquatableImmutableArray? Produces { get; init; } + public EquatableImmutableArray? ProducesProblem { get; init; } + public EquatableImmutableArray? ProducesValidationProblem { get; init; } + public bool ExcludeFromDescription { get; init; } + public bool RequireAuthorization { get; init; } + public EquatableImmutableArray? AuthorizationPolicies { get; init; } + public bool DisableAntiforgery { get; init; } + public bool AllowAnonymous { get; init; } + public bool RequireCors { get; init; } + public string? CorsPolicyName { get; init; } + public EquatableImmutableArray? RequiredHosts { get; init; } + public bool RequireRateLimiting { get; init; } + public string? RateLimitingPolicyName { get; init; } + public EquatableImmutableArray? EndpointFilterTypes { get; init; } + public bool ShortCircuit { get; init; } + public bool DisableValidation { get; init; } + public bool DisableRequestTimeout { get; init; } + public bool WithRequestTimeout { get; init; } + public string? RequestTimeoutPolicyName { get; init; } + public int? Order { get; init; } + public string? EndpointGroupName { get; init; } +} diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index a18a16a..e516a9a 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -1,5 +1,4 @@ using System.Collections.Immutable; -using System.ComponentModel; using System.Runtime.CompilerServices; using Microsoft.CodeAnalysis; using static GeneratedEndpoints.Common.Constants; @@ -10,7 +9,7 @@ internal static class EndpointConfigurationFactory { private static readonly ConditionalWeakTable GeneratedAttributeKindCache = new(); - public static EndpointConfiguration Create(ISymbol symbol, string? name, bool enforceMethodRequireAuthorizationRules) + public static EndpointConfiguration Create(ISymbol symbol, string? name) { var attributes = symbol.GetAttributes(); @@ -36,8 +35,8 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name, bool en string? rateLimitingPolicyName = null; List? endpointFilters = null; HashSet? endpointFilterSet = null; - bool hasAllowAnonymousAttribute = false; - bool hasRequireAuthorizationAttribute = false; + var hasAllowAnonymousAttribute = false; + var hasRequireAuthorizationAttribute = false; bool? shortCircuit = null; bool? disableValidation = null; bool? disableRequestTimeout = null; @@ -68,33 +67,18 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name, bool en requestTimeoutPolicyName = null; continue; case RequestHandlerAttributeKind.RequestTimeout: - { disableRequestTimeout = false; withRequestTimeout = true; - string? policyName = null; - if (attribute.ConstructorArguments.Length > 0) - policyName = (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString(); - requestTimeoutPolicyName = policyName; + requestTimeoutPolicyName = attribute.GetConstructorStringValue(); continue; - } case RequestHandlerAttributeKind.Order: - if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int orderValue) - order = orderValue; + order = attribute.GetConstructorIntValue(); continue; case RequestHandlerAttributeKind.MapGroup: - { - var groupName = attribute.GetNamedStringValue(NameAttributeNamedParameter); - if (!string.IsNullOrEmpty(groupName)) - endpointGroupName = groupName; + endpointGroupName = attribute.GetNamedStringValue(NameAttributeNamedParameter); continue; - } case RequestHandlerAttributeKind.Summary: - if (attribute.ConstructorArguments.Length > 0) - { - var summaryValue = (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString(); - if (!string.IsNullOrEmpty(summaryValue)) - summary = summaryValue; - } + summary = attribute.GetConstructorStringValue(); continue; case RequestHandlerAttributeKind.Accepts: TryAddAcceptsMetadata(attribute, attributeClass, ref accepts); @@ -105,44 +89,21 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name, bool en case RequestHandlerAttributeKind.RequireAuthorization: requireAuthorization = true; hasRequireAuthorizationAttribute = true; - if (attribute.ConstructorArguments.Length == 1) - { - var arg = attribute.ConstructorArguments[0]; - MergeInto(ref authorizationPolicies, arg.Values); - } - + var authorizationPoliciesValues = attribute.GetConstructorStringArray(); + MergeInto(ref authorizationPolicies, authorizationPoliciesValues); continue; case RequestHandlerAttributeKind.RequireCors: requireCors = true; - corsPolicyName = attribute.ConstructorArguments.Length > 0 - ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() - : null; + corsPolicyName = attribute.GetConstructorStringValue(); continue; case RequestHandlerAttributeKind.RequireHost: - if (attribute.ConstructorArguments.Length == 1) - { - var arg = attribute.ConstructorArguments[0]; - if (arg is { Kind: TypedConstantKind.Array, Values.Length: > 0 }) - MergeInto(ref requiredHosts, arg.Values); - else if (arg.Value is string singleHost && !string.IsNullOrWhiteSpace(singleHost)) - MergeInto(ref requiredHosts, [singleHost.Trim()]); - } - + var requiredHostsValues = attribute.GetConstructorStringArray(); + MergeInto(ref requiredHosts, requiredHostsValues); continue; case RequestHandlerAttributeKind.RequireRateLimiting: - { - var policyName = attribute.ConstructorArguments.Length > 0 - ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() - : null; - - if (!string.IsNullOrEmpty(policyName)) - { - requireRateLimiting = true; - rateLimitingPolicyName = policyName; - } - + requireRateLimiting = true; + rateLimitingPolicyName = attribute.GetConstructorStringValue(); continue; - } case RequestHandlerAttributeKind.EndpointFilter: TryAddEndpointFilter(attribute, attributeClass, ref endpointFilters, ref endpointFilterSet); continue; @@ -151,13 +112,9 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name, bool en continue; case RequestHandlerAttributeKind.ProducesProblem: { - var statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesProblemStatusCode - ? producesProblemStatusCode - : 500; - var contentType = attribute.ConstructorArguments.Length > 1 - ? (attribute.ConstructorArguments[1].Value as string).NormalizeOptionalString() - : null; - var additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; + var statusCode = attribute.GetConstructorIntValue() ?? 500; + var contentType = attribute.GetConstructorStringValue(1); + var additionalContentTypes = attribute.GetConstructorStringArray(2); var producesProblemList = producesProblem ??= []; producesProblemList.Add(new ProducesProblemMetadata(statusCode, contentType, additionalContentTypes)); @@ -165,35 +122,27 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name, bool en } case RequestHandlerAttributeKind.ProducesValidationProblem: { - var statusCode = - attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesValidationProblemStatusCode - ? producesValidationProblemStatusCode - : 400; - var contentType = attribute.ConstructorArguments.Length > 1 - ? (attribute.ConstructorArguments[1].Value as string).NormalizeOptionalString() - : null; - var additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; + var statusCode = attribute.GetConstructorIntValue() ?? 400; + var contentType = attribute.GetConstructorStringValue(1); + var additionalContentTypes = attribute.GetConstructorStringArray(2); var producesValidationProblemList = producesValidationProblem ??= []; producesValidationProblemList.Add(new ProducesValidationProblemMetadata(statusCode, contentType, additionalContentTypes)); continue; } case RequestHandlerAttributeKind.DisplayName: - displayName = attribute.ConstructorArguments.Length > 0 ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() : null; + displayName = attribute.GetConstructorStringValue(); break; case RequestHandlerAttributeKind.Description: - description = attribute.ConstructorArguments.Length > 0 ? (attribute.ConstructorArguments[0].Value as string).NormalizeOptionalString() : null; + description = attribute.GetConstructorStringValue(); break; case RequestHandlerAttributeKind.AllowAnonymous: allowAnonymous = true; hasAllowAnonymousAttribute = true; break; case RequestHandlerAttributeKind.Tags: - if (attribute.ConstructorArguments.Length > 0) - { - var arg = attribute.ConstructorArguments[0]; - MergeInto(ref tags, arg.Values); - } + var tagsValues = attribute.GetConstructorStringArray(); + MergeInto(ref tags, tagsValues); break; case RequestHandlerAttributeKind.ExcludeFromDescription: excludeFromDescription = true; @@ -204,20 +153,39 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name, bool en } } - - if (enforceMethodRequireAuthorizationRules && hasRequireAuthorizationAttribute && !hasAllowAnonymousAttribute) + if (hasRequireAuthorizationAttribute && !hasAllowAnonymousAttribute) allowAnonymous = false; - var withRequestTimeout1 = withRequestTimeout ?? false; - var requestTimeoutPolicyName1 = withRequestTimeout1 ? requestTimeoutPolicyName : null; - - return new EndpointConfiguration(name, displayName, summary, description, tags, ToEquatableOrNull(accepts), - ToEquatableOrNull(produces), ToEquatableOrNull(producesProblem), ToEquatableOrNull(producesValidationProblem), - excludeFromDescription ?? false, requireAuthorization ?? false, authorizationPolicies, disableAntiforgery ?? false, - allowAnonymous ?? false, requireCors ?? false, corsPolicyName, requiredHosts, requireRateLimiting ?? false, - rateLimitingPolicyName, ToEquatableOrNull(endpointFilters), shortCircuit ?? false, disableValidation ?? false, - disableRequestTimeout ?? false, withRequestTimeout1, requestTimeoutPolicyName1, order, endpointGroupName - ); + return new EndpointConfiguration + { + Name = name, + DisplayName = displayName, + Summary = summary, + Description = description, + Tags = tags, + Accepts = ToEquatableOrNull(accepts), + Produces = ToEquatableOrNull(produces), + ProducesProblem = ToEquatableOrNull(producesProblem), + ProducesValidationProblem = ToEquatableOrNull(producesValidationProblem), + ExcludeFromDescription = excludeFromDescription ?? false, + RequireAuthorization = requireAuthorization ?? false, + AuthorizationPolicies = authorizationPolicies, + DisableAntiforgery = disableAntiforgery ?? false, + AllowAnonymous = allowAnonymous ?? false, + RequireCors = requireCors ?? false, + CorsPolicyName = corsPolicyName, + RequiredHosts = requiredHosts, + RequireRateLimiting = requireRateLimiting ?? false, + RateLimitingPolicyName = rateLimitingPolicyName, + EndpointFilterTypes = ToEquatableOrNull(endpointFilters), + ShortCircuit = shortCircuit ?? false, + DisableValidation = disableValidation ?? false, + DisableRequestTimeout = disableRequestTimeout ?? false, + WithRequestTimeout = withRequestTimeout ?? false, + RequestTimeoutPolicyName = requestTimeoutPolicyName, + Order = order, + EndpointGroupName = endpointGroupName, + }; } public static RequestHandlerAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) @@ -268,33 +236,13 @@ private static string RemoveAsyncSuffix(string methodName) return methodName; } - private static void MergeInto(ref EquatableImmutableArray? target, IEnumerable values) - { - var merged = MergeUnion(target, values); - target = merged.Count > 0 ? merged : null; - } - - private static void MergeInto(ref EquatableImmutableArray? target, ImmutableArray values) + private static void MergeInto(ref EquatableImmutableArray? target, IEnumerable? values) { - if (values.IsDefaultOrEmpty) + if (values is null) return; - List? normalized = null; - foreach (var value in values) - { - if (value.Value is not string stringValue) - continue; - - var trimmed = stringValue.NormalizeOptionalString(); - if (trimmed is not { Length: > 0 }) - continue; - - normalized ??= new List(values.Length); - normalized.Add(trimmed); - } - - if (normalized is { Count: > 0 }) - MergeInto(ref target, normalized); + var merged = MergeUnion(target, values); + target = merged.Count > 0 ? merged : null; } private static EquatableImmutableArray? ToEquatableOrNull(List? values) @@ -427,7 +375,6 @@ private static void TryAddEndpointFilterType(ITypeSymbol? typeSymbol, ref List? MergeDistinctStrings(EquatableImmutableArray? first, EquatableImmutableArray? second) From 605f93e62c9535af4bd4c5f1efdf1a8b8c53a842 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 21:28:50 -0500 Subject: [PATCH 08/32] Refactor. --- .../Common/AttributeDataExtensions.cs | 10 ++--- .../Common/EndpointConfigurationFactory.cs | 43 ++++--------------- .../Common/StringExtensions.cs | 5 +++ .../GeneratedEndpoints.Tests.csproj | 9 ++++ 4 files changed, 27 insertions(+), 40 deletions(-) diff --git a/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs index 8fd6358..a14d6a1 100644 --- a/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs +++ b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs @@ -18,8 +18,11 @@ internal static class AttributeDataExtensions return null; var arg = attribute.ConstructorArguments[position]; - if (arg is { Kind: TypedConstantKind.Array, Values.Length: > 0 }) + if (arg.Kind == TypedConstantKind.Array) { + if (arg.Values.Length == 0) + return null; + List? normalized = null; foreach (var value in arg.Values) { @@ -36,11 +39,8 @@ internal static class AttributeDataExtensions if (normalized is { Count: > 0 }) return normalized.ToEquatableImmutableArray(); - - return null; } - - if (arg.Value is string singleHost && !string.IsNullOrWhiteSpace(singleHost)) + else if (arg.Value is string singleHost && !string.IsNullOrWhiteSpace(singleHost)) return new[] { singleHost.Trim() }.ToEquatableImmutableArray(); return null; diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index e516a9a..5c0ca65 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -35,8 +35,6 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name) string? rateLimitingPolicyName = null; List? endpointFilters = null; HashSet? endpointFilterSet = null; - var hasAllowAnonymousAttribute = false; - var hasRequireAuthorizationAttribute = false; bool? shortCircuit = null; bool? disableValidation = null; bool? disableRequestTimeout = null; @@ -63,13 +61,10 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name) continue; case RequestHandlerAttributeKind.DisableRequestTimeout: disableRequestTimeout = true; - withRequestTimeout = false; - requestTimeoutPolicyName = null; continue; case RequestHandlerAttributeKind.RequestTimeout: - disableRequestTimeout = false; - withRequestTimeout = true; requestTimeoutPolicyName = attribute.GetConstructorStringValue(); + withRequestTimeout = true; continue; case RequestHandlerAttributeKind.Order: order = attribute.GetConstructorIntValue(); @@ -87,22 +82,19 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name) TryAddProducesMetadata(attribute, attributeClass, ref produces); continue; case RequestHandlerAttributeKind.RequireAuthorization: + authorizationPolicies = attribute.GetConstructorStringArray(); requireAuthorization = true; - hasRequireAuthorizationAttribute = true; - var authorizationPoliciesValues = attribute.GetConstructorStringArray(); - MergeInto(ref authorizationPolicies, authorizationPoliciesValues); continue; case RequestHandlerAttributeKind.RequireCors: - requireCors = true; corsPolicyName = attribute.GetConstructorStringValue(); + requireCors = true; continue; case RequestHandlerAttributeKind.RequireHost: - var requiredHostsValues = attribute.GetConstructorStringArray(); - MergeInto(ref requiredHosts, requiredHostsValues); + requiredHosts = attribute.GetConstructorStringArray(); continue; case RequestHandlerAttributeKind.RequireRateLimiting: - requireRateLimiting = true; rateLimitingPolicyName = attribute.GetConstructorStringValue(); + requireRateLimiting = rateLimitingPolicyName is not null; continue; case RequestHandlerAttributeKind.EndpointFilter: TryAddEndpointFilter(attribute, attributeClass, ref endpointFilters, ref endpointFilterSet); @@ -138,11 +130,9 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name) break; case RequestHandlerAttributeKind.AllowAnonymous: allowAnonymous = true; - hasAllowAnonymousAttribute = true; break; case RequestHandlerAttributeKind.Tags: - var tagsValues = attribute.GetConstructorStringArray(); - MergeInto(ref tags, tagsValues); + tags = attribute.GetConstructorStringArray(); break; case RequestHandlerAttributeKind.ExcludeFromDescription: excludeFromDescription = true; @@ -153,9 +143,6 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name) } } - if (hasRequireAuthorizationAttribute && !hasAllowAnonymousAttribute) - allowAnonymous = false; - return new EndpointConfiguration { Name = name, @@ -236,25 +223,11 @@ private static string RemoveAsyncSuffix(string methodName) return methodName; } - private static void MergeInto(ref EquatableImmutableArray? target, IEnumerable? values) - { - if (values is null) - return; - - var merged = MergeUnion(target, values); - target = merged.Count > 0 ? merged : null; - } - private static EquatableImmutableArray? ToEquatableOrNull(List? values) { return values is { Count: > 0 } ? values.ToEquatableImmutableArray() : null; } - private static string NormalizeRequiredContentType(string? contentType, string defaultValue) - { - return string.IsNullOrWhiteSpace(contentType) ? defaultValue : contentType!.Trim(); - } - private static EquatableImmutableArray? GetStringArrayValues(TypedConstant typedConstant) { if (typedConstant.Kind != TypedConstantKind.Array || typedConstant.Values.IsDefaultOrEmpty) @@ -282,7 +255,7 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym requestType = attributeClass.TypeArguments[0] .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); contentType = attribute.ConstructorArguments.Length > 0 - ? NormalizeRequiredContentType(attribute.ConstructorArguments[0].Value as string, "application/json") + ? (attribute.ConstructorArguments[0].Value as string).NormalizeOrDefaultString("application/json") : "application/json"; additionalContentTypes = attribute.ConstructorArguments.Length > 1 ? GetStringArrayValues(attribute.ConstructorArguments[1]) : null; } @@ -290,7 +263,7 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym { requestType = requestTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); contentType = attribute.ConstructorArguments.Length > 0 - ? NormalizeRequiredContentType(attribute.ConstructorArguments[0].Value as string, "application/json") + ? (attribute.ConstructorArguments[0].Value as string).NormalizeOrDefaultString("application/json") : "application/json"; additionalContentTypes = attribute.ConstructorArguments.Length > 1 ? GetStringArrayValues(attribute.ConstructorArguments[1]) : null; } diff --git a/src/GeneratedEndpoints/Common/StringExtensions.cs b/src/GeneratedEndpoints/Common/StringExtensions.cs index 0c08744..5867811 100644 --- a/src/GeneratedEndpoints/Common/StringExtensions.cs +++ b/src/GeneratedEndpoints/Common/StringExtensions.cs @@ -70,4 +70,9 @@ public static string ToStringLiteral(this string? value) { return string.IsNullOrWhiteSpace(value) ? null : value!.Trim(); } + + public static string NormalizeOrDefaultString(this string? value, string defaultValue) + { + return string.IsNullOrWhiteSpace(value) ? defaultValue : value!.Trim(); + } } diff --git a/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj b/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj index 1f74d48..3f4e8b2 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj +++ b/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj @@ -38,4 +38,13 @@ + + + GeneratedSourceTests.cs + + + IndividualTests.cs + + + From 7fb06b60fb8994ba29b81259188f021c9d845a47 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 21:43:07 -0500 Subject: [PATCH 09/32] Refactored. --- src/GeneratedEndpoints/Common/Constants.cs | 1 + .../Common/EndpointConfigurationFactory.cs | 111 ++++-------------- src/GeneratedEndpoints/MinimalApiGenerator.cs | 32 ++++- 3 files changed, 52 insertions(+), 92 deletions(-) diff --git a/src/GeneratedEndpoints/Common/Constants.cs b/src/GeneratedEndpoints/Common/Constants.cs index 351f650..0b17341 100644 --- a/src/GeneratedEndpoints/Common/Constants.cs +++ b/src/GeneratedEndpoints/Common/Constants.cs @@ -102,6 +102,7 @@ internal static partial class Constants internal const string ConfigureMethodName = "Configure"; internal const string AsyncSuffix = "Async"; internal const string GlobalPrefix = "global::"; + internal const string ApplicationJsonContentType = "application/json"; internal static readonly string[] AttributesNamespaceParts = AttributesNamespace.Split('.'); internal static readonly string[] AspNetCoreHttpNamespaceParts = ["Microsoft", "AspNetCore", "Http"]; diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index 5c0ca65..d6c8104 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -107,9 +107,10 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name) var statusCode = attribute.GetConstructorIntValue() ?? 500; var contentType = attribute.GetConstructorStringValue(1); var additionalContentTypes = attribute.GetConstructorStringArray(2); + var producesProblemMetadata = new ProducesProblemMetadata(statusCode, contentType, additionalContentTypes); var producesProblemList = producesProblem ??= []; - producesProblemList.Add(new ProducesProblemMetadata(statusCode, contentType, additionalContentTypes)); + producesProblemList.Add(producesProblemMetadata); continue; } case RequestHandlerAttributeKind.ProducesValidationProblem: @@ -117,9 +118,10 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name) var statusCode = attribute.GetConstructorIntValue() ?? 400; var contentType = attribute.GetConstructorStringValue(1); var additionalContentTypes = attribute.GetConstructorStringArray(2); + var producesValidationProblemMetadata = new ProducesValidationProblemMetadata(statusCode, contentType, additionalContentTypes); var producesValidationProblemList = producesValidationProblem ??= []; - producesValidationProblemList.Add(new ProducesValidationProblemMetadata(statusCode, contentType, additionalContentTypes)); + producesValidationProblemList.Add(producesValidationProblemMetadata); continue; } case RequestHandlerAttributeKind.DisplayName: @@ -185,36 +187,6 @@ public static RequestHandlerAttributeKind GetGeneratedAttributeKind(INamedTypeSy return cacheEntry.Kind; } - internal static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, IEnumerable values) - { - List? list = null; - HashSet? seen = null; - - if (existing is { Count: > 0 }) - { - var count = existing.Value.Count; - list = new List(count + 4); - list.AddRange(existing.Value); - seen = new HashSet(existing.Value, StringComparer.OrdinalIgnoreCase); - } - - foreach (var value in values) - { - var normalized = value.NormalizeOptionalString(); - if (normalized is not { Length: > 0 }) - continue; - - seen ??= new HashSet(StringComparer.OrdinalIgnoreCase); - if (!seen.Add(normalized)) - continue; - - list ??= []; - list.Add(normalized); - } - - return list?.ToEquatableImmutableArray() ?? EquatableImmutableArray.Empty; - } - private static string RemoveAsyncSuffix(string methodName) { if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) @@ -228,87 +200,44 @@ private static string RemoveAsyncSuffix(string methodName) return values is { Count: > 0 } ? values.ToEquatableImmutableArray() : null; } - private static EquatableImmutableArray? GetStringArrayValues(TypedConstant typedConstant) - { - if (typedConstant.Kind != TypedConstantKind.Array || typedConstant.Values.IsDefaultOrEmpty) - return null; - - var builder = ImmutableArray.CreateBuilder(typedConstant.Values.Length); - foreach (var value in typedConstant.Values) - { - if (value.Value is string s && !string.IsNullOrWhiteSpace(s)) - builder.Add(s.Trim()); - } - - return builder.Count > 0 ? builder.ToEquatableImmutable() : null; - } - private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List? accepts) { string? requestType; - string contentType; - EquatableImmutableArray? additionalContentTypes; - var isOptional = attribute.GetNamedBoolValue(IsOptionalAttributeNamedParameter); - if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) - { - requestType = attributeClass.TypeArguments[0] - .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - contentType = attribute.ConstructorArguments.Length > 0 - ? (attribute.ConstructorArguments[0].Value as string).NormalizeOrDefaultString("application/json") - : "application/json"; - additionalContentTypes = attribute.ConstructorArguments.Length > 1 ? GetStringArrayValues(attribute.ConstructorArguments[1]) : null; - } + requestType = attributeClass.TypeArguments[0].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); else if (attribute.GetNamedTypeSymbol(RequestTypeAttributeNamedParameter) is { } requestTypeSymbol) - { requestType = requestTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - contentType = attribute.ConstructorArguments.Length > 0 - ? (attribute.ConstructorArguments[0].Value as string).NormalizeOrDefaultString("application/json") - : "application/json"; - additionalContentTypes = attribute.ConstructorArguments.Length > 1 ? GetStringArrayValues(attribute.ConstructorArguments[1]) : null; - } else - { return; - } + + var contentType = attribute.GetConstructorStringValue() ?? ApplicationJsonContentType; + var additionalContentTypes = attribute.GetConstructorStringArray(position: 1); + var isOptional = attribute.GetNamedBoolValue(IsOptionalAttributeNamedParameter); + + var acceptMetadata = new AcceptsMetadata(requestType, contentType, additionalContentTypes, isOptional); var acceptsList = accepts ??= []; - acceptsList.Add(new AcceptsMetadata(requestType, contentType, additionalContentTypes, isOptional)); + acceptsList.Add(acceptMetadata); } private static void TryAddProducesMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List? produces) { string? responseType; - int statusCode; - string? contentType; - EquatableImmutableArray? additionalContentTypes; - if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) - { - responseType = attributeClass.TypeArguments[0] - .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode - ? producesStatusCode - : 200; - contentType = attribute.ConstructorArguments.Length > 1 ? (attribute.ConstructorArguments[1].Value as string).NormalizeOptionalString() : null; - additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; - } + responseType = attributeClass.TypeArguments[0].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); else if (attribute.GetNamedTypeSymbol(ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol) - { responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode - ? producesStatusCode - : 200; - contentType = attribute.ConstructorArguments.Length > 1 ? (attribute.ConstructorArguments[1].Value as string).NormalizeOptionalString() : null; - additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; - } else - { return; - } + + var statusCode = attribute.GetConstructorIntValue() ?? 200; + var contentType = attribute.GetConstructorStringValue(position: 1); + var additionalContentTypes = attribute.GetConstructorStringArray(position: 2); + + var producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes); var producesList = produces ??= []; - producesList.Add(new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes)); + producesList.Add(producesMetadata); } private static void TryAddEndpointFilter( diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 6a550ec..a91e391 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -906,10 +906,40 @@ private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfigu if (second is not { Count: > 0 }) return first; - var merged = EndpointConfigurationFactory.MergeUnion(first, second.Value); + var merged = MergeUnion(first, second.Value); return merged.Count > 0 ? merged : null; } + private static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, EquatableImmutableArray values) + { + List? list = null; + HashSet? seen = null; + + if (existing is { Count: > 0 }) + { + var count = existing.Value.Count; + list = new List(count + 4); + list.AddRange(existing.Value); + seen = new HashSet(existing.Value, StringComparer.OrdinalIgnoreCase); + } + + foreach (var value in values) + { + var normalized = value.NormalizeOptionalString(); + if (normalized is not { Length: > 0 }) + continue; + + seen ??= new HashSet(StringComparer.OrdinalIgnoreCase); + if (!seen.Add(normalized)) + continue; + + list ??= []; + list.Add(normalized); + } + + return list?.ToEquatableImmutableArray() ?? EquatableImmutableArray.Empty; + } + private static EquatableImmutableArray? ConcatEquatable(EquatableImmutableArray? first, EquatableImmutableArray? second) { if (first is not { Count: > 0 }) From ce11fd61a1a989cf5bf42b858635e03e72415b07 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 22:07:39 -0500 Subject: [PATCH 10/32] Refactor. --- src/GeneratedEndpoints/Common/Constants.cs | 45 +++++++-------- .../Common/EndpointConfiguration.cs | 56 ++++++++++--------- .../Common/EndpointConfigurationFactory.cs | 35 ++++++++++-- .../Common/RequestHandlerClass.cs | 2 - .../Common/RequestHandlerClassCacheEntry.cs | 37 +----------- src/GeneratedEndpoints/MinimalApiGenerator.cs | 23 +++++--- .../GeneratedEndpoints.Tests.csproj | 9 +++ .../IndividualTests.cs | 7 --- 8 files changed, 107 insertions(+), 107 deletions(-) diff --git a/src/GeneratedEndpoints/Common/Constants.cs b/src/GeneratedEndpoints/Common/Constants.cs index 0b17341..414a6a8 100644 --- a/src/GeneratedEndpoints/Common/Constants.cs +++ b/src/GeneratedEndpoints/Common/Constants.cs @@ -4,8 +4,8 @@ namespace GeneratedEndpoints.Common; internal static partial class Constants { - internal const string BaseNamespace = "Microsoft.AspNetCore.Generated"; - internal const string AttributesNamespace = $"{BaseNamespace}.Attributes"; + private const string BaseNamespace = "Microsoft.AspNetCore.Generated"; + private const string AttributesNamespace = $"{BaseNamespace}.Attributes"; internal const string FallbackHttpMethod = "__FALLBACK__"; @@ -13,54 +13,53 @@ internal static partial class Constants internal const string ResponseTypeAttributeNamedParameter = "ResponseType"; internal const string RequestTypeAttributeNamedParameter = "RequestType"; internal const string IsOptionalAttributeNamedParameter = "IsOptional"; - internal const string PolicyNameAttributeNamedParameter = "PolicyName"; internal const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute"; - internal const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}"; + private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}"; internal const string RequireAuthorizationAttributeHint = $"{RequireAuthorizationAttributeFullyQualifiedName}.gs.cs"; internal const string RequireCorsAttributeName = "RequireCorsAttribute"; - internal const string RequireCorsAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireCorsAttributeName}"; + private const string RequireCorsAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireCorsAttributeName}"; internal const string RequireCorsAttributeHint = $"{RequireCorsAttributeFullyQualifiedName}.gs.cs"; internal const string RequireRateLimitingAttributeName = "RequireRateLimitingAttribute"; - internal const string RequireRateLimitingAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireRateLimitingAttributeName}"; + private const string RequireRateLimitingAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireRateLimitingAttributeName}"; internal const string RequireRateLimitingAttributeHint = $"{RequireRateLimitingAttributeFullyQualifiedName}.gs.cs"; internal const string RequireHostAttributeName = "RequireHostAttribute"; - internal const string RequireHostAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireHostAttributeName}"; + private const string RequireHostAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireHostAttributeName}"; internal const string RequireHostAttributeHint = $"{RequireHostAttributeFullyQualifiedName}.gs.cs"; internal const string DisableAntiforgeryAttributeName = "DisableAntiforgeryAttribute"; - internal const string DisableAntiforgeryAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableAntiforgeryAttributeName}"; + private const string DisableAntiforgeryAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableAntiforgeryAttributeName}"; internal const string DisableAntiforgeryAttributeHint = $"{DisableAntiforgeryAttributeFullyQualifiedName}.gs.cs"; internal const string ShortCircuitAttributeName = "ShortCircuitAttribute"; - internal const string ShortCircuitAttributeFullyQualifiedName = $"{AttributesNamespace}.{ShortCircuitAttributeName}"; + private const string ShortCircuitAttributeFullyQualifiedName = $"{AttributesNamespace}.{ShortCircuitAttributeName}"; internal const string ShortCircuitAttributeHint = $"{ShortCircuitAttributeFullyQualifiedName}.gs.cs"; internal const string DisableRequestTimeoutAttributeName = "DisableRequestTimeoutAttribute"; - internal const string DisableRequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableRequestTimeoutAttributeName}"; + private const string DisableRequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableRequestTimeoutAttributeName}"; internal const string DisableRequestTimeoutAttributeHint = $"{DisableRequestTimeoutAttributeFullyQualifiedName}.gs.cs"; internal const string DisableValidationAttributeName = "DisableValidationAttribute"; - internal const string DisableValidationAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableValidationAttributeName}"; + private const string DisableValidationAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableValidationAttributeName}"; internal const string DisableValidationAttributeHint = $"{DisableValidationAttributeFullyQualifiedName}.gs.cs"; internal const string RequestTimeoutAttributeName = "RequestTimeoutAttribute"; - internal const string RequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequestTimeoutAttributeName}"; + private const string RequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequestTimeoutAttributeName}"; internal const string RequestTimeoutAttributeHint = $"{RequestTimeoutAttributeFullyQualifiedName}.gs.cs"; internal const string OrderAttributeName = "OrderAttribute"; - internal const string OrderAttributeFullyQualifiedName = $"{AttributesNamespace}.{OrderAttributeName}"; + private const string OrderAttributeFullyQualifiedName = $"{AttributesNamespace}.{OrderAttributeName}"; internal const string OrderAttributeHint = $"{OrderAttributeFullyQualifiedName}.gs.cs"; internal const string MapGroupAttributeName = "MapGroupAttribute"; - internal const string MapGroupAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapGroupAttributeName}"; + private const string MapGroupAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapGroupAttributeName}"; internal const string MapGroupAttributeHint = $"{MapGroupAttributeFullyQualifiedName}.gs.cs"; internal const string SummaryAttributeName = "SummaryAttribute"; - internal const string SummaryAttributeFullyQualifiedName = $"{AttributesNamespace}.{SummaryAttributeName}"; + private const string SummaryAttributeFullyQualifiedName = $"{AttributesNamespace}.{SummaryAttributeName}"; internal const string SummaryAttributeHint = $"{SummaryAttributeFullyQualifiedName}.gs.cs"; internal const string DisplayNameAttributeName = nameof(DisplayNameAttribute); @@ -70,34 +69,36 @@ internal static partial class Constants internal const string ExcludeFromDescriptionAttributeName = "ExcludeFromDescriptionAttribute"; internal const string EndpointFilterAttributeName = "EndpointFilterAttribute"; - internal const string EndpointFilterAttributeFullyQualifiedName = $"{AttributesNamespace}.{EndpointFilterAttributeName}"; + private const string EndpointFilterAttributeFullyQualifiedName = $"{AttributesNamespace}.{EndpointFilterAttributeName}"; internal const string EndpointFilterAttributeHint = $"{EndpointFilterAttributeFullyQualifiedName}.gs.cs"; internal const string AcceptsAttributeName = "AcceptsAttribute"; - internal const string AcceptsAttributeFullyQualifiedName = $"{AttributesNamespace}.{AcceptsAttributeName}"; + private const string AcceptsAttributeFullyQualifiedName = $"{AttributesNamespace}.{AcceptsAttributeName}"; internal const string AcceptsAttributeHint = $"{AcceptsAttributeFullyQualifiedName}.gs.cs"; internal const string ProducesResponseAttributeName = "ProducesResponseAttribute"; - internal const string ProducesResponseAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesResponseAttributeName}"; + private const string ProducesResponseAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesResponseAttributeName}"; internal const string ProducesResponseAttributeHint = $"{ProducesResponseAttributeFullyQualifiedName}.gs.cs"; internal const string ProducesProblemAttributeName = "ProducesProblemAttribute"; - internal const string ProducesProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesProblemAttributeName}"; + private const string ProducesProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesProblemAttributeName}"; internal const string ProducesProblemAttributeHint = $"{ProducesProblemAttributeFullyQualifiedName}.gs.cs"; internal const string ProducesValidationProblemAttributeName = "ProducesValidationProblemAttribute"; - internal const string ProducesValidationProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesValidationProblemAttributeName}"; + private const string ProducesValidationProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesValidationProblemAttributeName}"; internal const string ProducesValidationProblemAttributeHint = $"{ProducesValidationProblemAttributeFullyQualifiedName}.gs.cs"; internal const string RoutingNamespace = $"{BaseNamespace}.Routing"; internal const string AddEndpointHandlersClassName = "EndpointServicesExtensions"; internal const string AddEndpointHandlersMethodName = "AddEndpointHandlers"; - internal const string AddEndpointHandlersMethodHint = $"{RoutingNamespace}.{AddEndpointHandlersMethodName}.g.cs"; + private const string AddEndpointHandlersMethodFullyQualifiedName = $"{RoutingNamespace}.{AddEndpointHandlersMethodName}"; + internal const string AddEndpointHandlersMethodHint = $"{AddEndpointHandlersMethodFullyQualifiedName}.g.cs"; internal const string UseEndpointHandlersClassName = "EndpointRouteBuilderExtensions"; internal const string UseEndpointHandlersMethodName = "MapEndpointHandlers"; - internal const string UseEndpointHandlersMethodHint = $"{RoutingNamespace}.{UseEndpointHandlersMethodName}.g.cs"; + private const string UseEndpointHandlersMethodFullyQualifiedName = $"{RoutingNamespace}.{UseEndpointHandlersMethodName}"; + internal const string UseEndpointHandlersMethodHint = $"{UseEndpointHandlersMethodFullyQualifiedName}.g.cs"; internal const string ConfigureMethodName = "Configure"; internal const string AsyncSuffix = "Async"; diff --git a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs index 6ebf038..2a79eaa 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs @@ -2,31 +2,33 @@ namespace GeneratedEndpoints.Common; internal readonly record struct EndpointConfiguration { - public string? Name { get; init; } - public string? DisplayName { get; init; } - public string? Summary { get; init; } - public string? Description { get; init; } - public EquatableImmutableArray? Tags { get; init; } - public EquatableImmutableArray? Accepts { get; init; } - public EquatableImmutableArray? Produces { get; init; } - public EquatableImmutableArray? ProducesProblem { get; init; } - public EquatableImmutableArray? ProducesValidationProblem { get; init; } - public bool ExcludeFromDescription { get; init; } - public bool RequireAuthorization { get; init; } - public EquatableImmutableArray? AuthorizationPolicies { get; init; } - public bool DisableAntiforgery { get; init; } - public bool AllowAnonymous { get; init; } - public bool RequireCors { get; init; } - public string? CorsPolicyName { get; init; } - public EquatableImmutableArray? RequiredHosts { get; init; } - public bool RequireRateLimiting { get; init; } - public string? RateLimitingPolicyName { get; init; } - public EquatableImmutableArray? EndpointFilterTypes { get; init; } - public bool ShortCircuit { get; init; } - public bool DisableValidation { get; init; } - public bool DisableRequestTimeout { get; init; } - public bool WithRequestTimeout { get; init; } - public string? RequestTimeoutPolicyName { get; init; } - public int? Order { get; init; } - public string? EndpointGroupName { get; init; } + public required string? Name { get; init; } + public required string? DisplayName { get; init; } + public required string? Summary { get; init; } + public required string? Description { get; init; } + public required EquatableImmutableArray? Tags { get; init; } + public required EquatableImmutableArray? Accepts { get; init; } + public required EquatableImmutableArray? Produces { get; init; } + public required EquatableImmutableArray? ProducesProblem { get; init; } + public required EquatableImmutableArray? ProducesValidationProblem { get; init; } + public required bool ExcludeFromDescription { get; init; } + public required bool RequireAuthorization { get; init; } + public required EquatableImmutableArray? AuthorizationPolicies { get; init; } + public required bool DisableAntiforgery { get; init; } + public required bool AllowAnonymous { get; init; } + public required bool RequireCors { get; init; } + public required string? CorsPolicyName { get; init; } + public required EquatableImmutableArray? RequiredHosts { get; init; } + public required bool RequireRateLimiting { get; init; } + public required string? RateLimitingPolicyName { get; init; } + public required EquatableImmutableArray? EndpointFilterTypes { get; init; } + public required bool ShortCircuit { get; init; } + public required bool DisableValidation { get; init; } + public required bool DisableRequestTimeout { get; init; } + public required bool WithRequestTimeout { get; init; } + public required string? RequestTimeoutPolicyName { get; init; } + public required int? Order { get; init; } + public required string? GroupIdentifier { get; init; } + public required string? GroupPattern { get; init; } + public required string? GroupName { get; init; } } diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index d6c8104..9d7f2fa 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -11,6 +11,7 @@ internal static class EndpointConfigurationFactory public static EndpointConfiguration Create(ISymbol symbol, string? name) { + var attributes = symbol.GetAttributes(); if (symbol is IMethodSymbol) @@ -41,7 +42,9 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name) bool? withRequestTimeout = null; string? requestTimeoutPolicyName = null; int? order = null; - string? endpointGroupName = null; + string? groupIdentifier = null; + string? groupPattern = null; + string? groupName = null; string? summary = null; foreach (var attribute in attributes) @@ -70,7 +73,9 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name) order = attribute.GetConstructorIntValue(); continue; case RequestHandlerAttributeKind.MapGroup: - endpointGroupName = attribute.GetNamedStringValue(NameAttributeNamedParameter); + groupIdentifier = GetMapGroupIdentifier(symbol); + groupPattern = attribute.GetConstructorStringValue() ?? ""; + groupName = attribute.GetNamedStringValue(NameAttributeNamedParameter); continue; case RequestHandlerAttributeKind.Summary: summary = attribute.GetConstructorStringValue(); @@ -173,11 +178,33 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name) WithRequestTimeout = withRequestTimeout ?? false, RequestTimeoutPolicyName = requestTimeoutPolicyName, Order = order, - EndpointGroupName = endpointGroupName, + GroupIdentifier = groupIdentifier, + GroupPattern = groupPattern, + GroupName = groupName, }; } - public static RequestHandlerAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) + private static string? GetMapGroupIdentifier(ISymbol symbol) + { + if (symbol is not INamedTypeSymbol namedTypeSymbol) + return null; + + var className = namedTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + if (className.StartsWith(GlobalPrefix, StringComparison.Ordinal)) + className = className[GlobalPrefix.Length..]; + + var builder = StringBuilderPool.Get(className.Length + 8); + builder.Append('_'); + + foreach (var character in className) + builder.Append(char.IsLetterOrDigit(character) ? character : '_'); + + builder.Append("_Group"); + return StringBuilderPool.ToStringAndReturn(builder); + } + + private static RequestHandlerAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) { var definition = attributeClass.OriginalDefinition; var cacheEntry = GeneratedAttributeKindCache.GetValue( diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClass.cs b/src/GeneratedEndpoints/Common/RequestHandlerClass.cs index baae895..c5e868b 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClass.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClass.cs @@ -5,7 +5,5 @@ internal readonly record struct RequestHandlerClass( bool IsStatic, bool HasConfigureMethod, bool ConfigureMethodAcceptsServiceProvider, - string? MapGroupPattern, - string? MapGroupBuilderIdentifier, EndpointConfiguration Configuration ); diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs index c4682ec..6d8d368 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs @@ -28,12 +28,10 @@ public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, Compilation compilationCache.ServiceProviderSymbol, cancellationToken ); - var mapGroupPattern = GetMapGroupPattern(classSymbol); - var mapGroupIdentifier = mapGroupPattern is null ? null : GetMapGroupIdentifier(name); var classConfiguration = EndpointConfigurationFactory.Create(classSymbol, null); _value = new RequestHandlerClass(name, isStatic, configureMethodDetails.HasConfigureMethod, - configureMethodDetails.ConfigureMethodAcceptsServiceProvider, mapGroupPattern, mapGroupIdentifier, classConfiguration + configureMethodDetails.ConfigureMethodAcceptsServiceProvider, classConfiguration ); _initialized = true; return _value; @@ -180,37 +178,4 @@ private static bool MatchesServiceProvider(ITypeSymbol typeSymbol) var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? ""; return string.Equals(containingNamespace, "System", StringComparison.Ordinal); } - - private static string? GetMapGroupPattern(INamedTypeSymbol classSymbol) - { - foreach (var attribute in classSymbol.GetAttributes()) - { - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - if (EndpointConfigurationFactory.GetGeneratedAttributeKind(attributeClass) != RequestHandlerAttributeKind.MapGroup) - continue; - - if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is string pattern) - return pattern.Trim(); - } - - return null; - } - - private static string GetMapGroupIdentifier(string className) - { - if (className.StartsWith(GlobalPrefix, StringComparison.Ordinal)) - className = className[GlobalPrefix.Length..]; - - var builder = StringBuilderPool.Get(className.Length + 8); - builder.Append('_'); - - foreach (var character in className) - builder.Append(char.IsLetterOrDigit(character) ? character : '_'); - - builder.Append("_Group"); - return StringBuilderPool.ToStringAndReturn(builder); - } } diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index a91e391..dab9a41 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -420,9 +420,9 @@ private static void GenerateUseEndpointHandlersClass(SourceProductionContext con { var groupedClass = groupedClasses[index]; source.Append(" var "); - source.Append(groupedClass.MapGroupBuilderIdentifier); + source.Append(groupedClass.Configuration.GroupIdentifier); source.Append(" = builder.MapGroup("); - source.Append(groupedClass.MapGroupPattern!.ToStringLiteral()); + source.Append(groupedClass.Configuration.GroupPattern!.ToStringLiteral()); source.Append(')'); AppendEndpointConfiguration(source, " ", groupedClass.Configuration, false); source.AppendLine(";"); @@ -475,7 +475,7 @@ private static List GetClassesWithMapGroups(ImmutableArray< { var handler = requestHandlers[index]; var handlerClass = handler.Class; - if (handlerClass.MapGroupPattern is null) + if (handlerClass.Configuration.GroupPattern is null) continue; if (seen.Add(handlerClass.Name)) @@ -491,7 +491,7 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl var configureAcceptsServiceProvider = requestHandler.Class.ConfigureMethodAcceptsServiceProvider; var indent = wrapWithConfigure ? " " : " "; var continuationIndent = indent + " "; - var routeBuilderIdentifier = requestHandler.Class.MapGroupBuilderIdentifier ?? "builder"; + var routeBuilderIdentifier = requestHandler.Class.Configuration.GroupIdentifier ?? "builder"; if (wrapWithConfigure) { @@ -565,7 +565,7 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl source.Append(')'); var configuration = requestHandler.Configuration; - if (requestHandler.Class.MapGroupPattern is null) + if (requestHandler.Class.Configuration.GroupPattern is null) configuration = MergeEndpointConfigurations(requestHandler.Class.Configuration, configuration); AppendEndpointConfiguration(source, continuationIndent, configuration, true); @@ -627,12 +627,12 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.Append(')'); } - if (!string.IsNullOrEmpty(configuration.EndpointGroupName)) + if (!string.IsNullOrEmpty(configuration.GroupName)) { source.AppendLine(); source.Append(indent); source.Append(".WithGroupName("); - source.Append(configuration.EndpointGroupName.ToStringLiteral()); + source.Append(configuration.GroupName.ToStringLiteral()); source.Append(')'); } @@ -852,6 +852,10 @@ private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfigu var disableValidation = classConfiguration.DisableValidation || methodConfiguration.DisableValidation; var disableRequestTimeout = classConfiguration.DisableRequestTimeout || methodConfiguration.DisableRequestTimeout; var withRequestTimeout = classConfiguration.WithRequestTimeout || methodConfiguration.WithRequestTimeout; + var groupIdentifier = classConfiguration.GroupIdentifier ?? methodConfiguration.GroupIdentifier; + var groupPattern = classConfiguration.GroupPattern ?? methodConfiguration.GroupPattern; + var groupName = classConfiguration.GroupName ?? methodConfiguration.GroupName; + string? requestTimeoutPolicyName = null; if (methodConfiguration.WithRequestTimeout) requestTimeoutPolicyName = methodConfiguration.RequestTimeoutPolicyName; @@ -865,7 +869,6 @@ private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfigu } var order = methodConfiguration.Order ?? classConfiguration.Order; - var endpointGroupName = methodConfiguration.EndpointGroupName ?? classConfiguration.EndpointGroupName; return new EndpointConfiguration { @@ -895,7 +898,9 @@ private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfigu WithRequestTimeout = withRequestTimeout, RequestTimeoutPolicyName = requestTimeoutPolicyName, Order = order, - EndpointGroupName = endpointGroupName, + GroupIdentifier = groupIdentifier, + GroupPattern = groupPattern, + GroupName = groupName, }; } diff --git a/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj b/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj index 3f4e8b2..9078773 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj +++ b/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj @@ -45,6 +45,15 @@ IndividualTests.cs + + GeneratedSourceTests.cs + + + GeneratedSourceTests.cs + + + GeneratedSourceTests.cs + diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.cs b/tests/GeneratedEndpoints.Tests/IndividualTests.cs index e6c0265..c2dcfe6 100644 --- a/tests/GeneratedEndpoints.Tests/IndividualTests.cs +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.cs @@ -172,13 +172,6 @@ public async Task OrderMetadata() await VerifyIndividualAsync(source, nameof(OrderMetadata)); } - [Fact] - public async Task GroupName() - { - var source = AuthorizationScenario(groupName: "IndividualGroup"); - await VerifyIndividualAsync(source, nameof(GroupName)); - } - [Fact] public async Task ClassMapGroup() { From 4803be215d3026ea22d59e6ca54be1ce6f5b584a Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 22:14:40 -0500 Subject: [PATCH 11/32] Add MapGroup individual test (#62) --- tests/GeneratedEndpoints.Tests/IndividualTests.cs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.cs b/tests/GeneratedEndpoints.Tests/IndividualTests.cs index c2dcfe6..fbeff64 100644 --- a/tests/GeneratedEndpoints.Tests/IndividualTests.cs +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.cs @@ -194,6 +194,13 @@ public async Task ClassMapGroup() await VerifyIndividualAsync(source, nameof(ClassMapGroup)); } + [Fact] + public async Task GroupName() + { + var source = AuthorizationScenario(groupName: "IndividualGroup"); + await VerifyIndividualAsync(source, nameof(GroupName)); + } + [Fact] public async Task ExcludeFromDescription() { From ab6c288fa2be415a6aa05e20b652f41f3afc03c5 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 22:16:26 -0500 Subject: [PATCH 12/32] Refactored. --- src/GeneratedEndpoints/MinimalApiGenerator.cs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index dab9a41..fb0fa3e 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -37,7 +37,7 @@ public void Initialize(IncrementalGeneratorInitializationContext context) requestHandlerProviders.Add(handlers); } - var requestHandlers = CombineRequestHandlers(requestHandlerProviders.MoveToImmutable()); + var requestHandlers = CombineRequestHandlers(requestHandlerProviders.MoveToImmutable()).Select((x, _) => x.ToEquatableImmutableArray()); context.RegisterSourceOutput(requestHandlers, GenerateSource); } @@ -174,7 +174,7 @@ private static CompilationTypeCache GetCompilationTypeCache(Compilation compilat return CompilationTypeCaches.GetValue(compilation, static c => new CompilationTypeCache(c)); } - private static void GenerateSource(SourceProductionContext context, ImmutableArray requestHandlers) + private static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) { context.CancellationToken.ThrowIfCancellationRequested(); @@ -185,10 +185,10 @@ private static void GenerateSource(SourceProductionContext context, ImmutableArr GenerateUseEndpointHandlersClass(context, sorted); } - private static ImmutableArray SortRequestHandlers(ImmutableArray requestHandlers) + private static ImmutableArray SortRequestHandlers(EquatableImmutableArray requestHandlers) { - if (requestHandlers.Length <= 1) - return requestHandlers; + if (requestHandlers.Count <= 1) + return [..requestHandlers]; var array = requestHandlers.ToArray(); Array.Sort(array, RequestHandlerComparer.Instance); From e5b3e3748480645e492ee79488962e380b8cbce6 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 22:25:37 -0500 Subject: [PATCH 13/32] Refactor. --- .../AddEndpointHandlersGenerator.cs | 106 +++ src/GeneratedEndpoints/MinimalApiGenerator.cs | 754 +----------------- .../UseEndpointHandlersGenerator.cs | 662 +++++++++++++++ 3 files changed, 779 insertions(+), 743 deletions(-) create mode 100644 src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs create mode 100644 src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs diff --git a/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs new file mode 100644 index 0000000..b69c6fc --- /dev/null +++ b/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs @@ -0,0 +1,106 @@ +using System.Collections.Immutable; +using System.Text; +using GeneratedEndpoints.Common; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; +using static GeneratedEndpoints.Common.Constants; + +namespace GeneratedEndpoints; + +// ReSharper disable ForCanBeConvertedToForeach +// ReSharper disable LoopCanBeConvertedToQuery +// Do not refactor, use for loop to avoid allocations. + +internal static class AddEndpointHandlersGenerator +{ + public static void GenerateSource(SourceProductionContext context, ImmutableArray requestHandlers) + { + context.CancellationToken.ThrowIfCancellationRequested(); + + var nonStaticClassNames = GetDistinctNonStaticClassNames(requestHandlers); + var source = GetAddEndpointHandlersStringBuilder(nonStaticClassNames); + source.AppendLine(FileHeader); + + source.AppendLine(); + + source.AppendLine("using Microsoft.Extensions.DependencyInjection;"); + source.AppendLine("using Microsoft.Extensions.DependencyInjection.Extensions;"); + source.AppendLine(); + + source.Append("namespace "); + source.Append(RoutingNamespace); + source.AppendLine(";"); + + source.AppendLine(); + + source.Append("internal static class "); + source.Append(AddEndpointHandlersClassName); + source.AppendLine(); + + source.AppendLine("{"); + + source.Append(" internal static void "); + source.Append(AddEndpointHandlersMethodName); + source.AppendLine("(this IServiceCollection services)"); + + source.AppendLine(" {"); + + foreach (var className in nonStaticClassNames) + { + source.Append(" services.TryAddScoped<"); + source.Append(className); + source.Append(">();"); + source.AppendLine(); + } + + source.AppendLine(""" + } + } + """ + ); + + var sourceText = StringBuilderPool.ToStringAndReturn(source); + context.AddSource(AddEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); + } + + private static List GetDistinctNonStaticClassNames(ImmutableArray requestHandlers) + { + var classNames = new List(); + if (requestHandlers.IsDefaultOrEmpty) + return classNames; + + var seen = new HashSet(StringComparer.Ordinal); + for (var index = 0; index < requestHandlers.Length; index++) + { + var requestHandler = requestHandlers[index]; + if (requestHandler.Class.IsStatic) + continue; + + var className = requestHandler.Class.Name; + if (seen.Add(className)) + classNames.Add(className); + } + + return classNames; + } + + private static StringBuilder GetAddEndpointHandlersStringBuilder(List nonStaticClassNames) + { + var estimate = 512L; + for (var index = 0; index < nonStaticClassNames.Count; index++) + { + var className = nonStaticClassNames[index]; + estimate += 36 + className.Length; + } + + estimate += Math.Max(256, nonStaticClassNames.Count * 12); + estimate = (long)(estimate * 1.10); + + if (estimate < 512) + estimate = 512; + else if (estimate > int.MaxValue) + estimate = int.MaxValue; + + return StringBuilderPool.Get((int)estimate); + } +} diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index fb0fa3e..608f931 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -1,11 +1,9 @@ using System.Buffers; using System.Collections.Immutable; using System.Runtime.CompilerServices; -using System.Text; using GeneratedEndpoints.Common; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Text; using static GeneratedEndpoints.Common.Constants; namespace GeneratedEndpoints; @@ -37,7 +35,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) requestHandlerProviders.Add(handlers); } - var requestHandlers = CombineRequestHandlers(requestHandlerProviders.MoveToImmutable()).Select((x, _) => x.ToEquatableImmutableArray()); + var requestHandlers = CombineRequestHandlers(requestHandlerProviders.MoveToImmutable()) + .Select((x, _) => x.ToEquatableImmutableArray()); context.RegisterSourceOutput(requestHandlers, GenerateSource); } @@ -178,11 +177,18 @@ private static void GenerateSource(SourceProductionContext context, EquatableImm { context.CancellationToken.ThrowIfCancellationRequested(); + var normalized = NormalizeRequestHandlers(requestHandlers); + + AddEndpointHandlersGenerator.GenerateSource(context, normalized); + UseEndpointHandlersGenerator.GenerateSource(context, normalized); + } + + private static ImmutableArray NormalizeRequestHandlers(EquatableImmutableArray requestHandlers) + { var sorted = SortRequestHandlers(requestHandlers); sorted = EnsureUniqueEndpointNames(sorted); - GenerateAddEndpointHandlersClass(context, sorted); - GenerateUseEndpointHandlersClass(context, sorted); + return sorted; } private static ImmutableArray SortRequestHandlers(EquatableImmutableArray requestHandlers) @@ -286,742 +292,4 @@ private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestH return string.Concat(className, ".", requestHandler.Method.Name); } - - private static void GenerateAddEndpointHandlersClass(SourceProductionContext context, ImmutableArray requestHandlers) - { - context.CancellationToken.ThrowIfCancellationRequested(); - - var nonStaticClassNames = GetDistinctNonStaticClassNames(requestHandlers); - var source = GetAddEndpointHandlersStringBuilder(nonStaticClassNames); - source.AppendLine(FileHeader); - - source.AppendLine(); - - source.AppendLine("using Microsoft.Extensions.DependencyInjection;"); - source.AppendLine("using Microsoft.Extensions.DependencyInjection.Extensions;"); - source.AppendLine(); - - source.Append("namespace "); - source.Append(RoutingNamespace); - source.AppendLine(";"); - - source.AppendLine(); - - source.Append("internal static class "); - source.Append(AddEndpointHandlersClassName); - source.AppendLine(); - - source.AppendLine("{"); - - source.Append(" internal static void "); - source.Append(AddEndpointHandlersMethodName); - source.AppendLine("(this IServiceCollection services)"); - - source.AppendLine(" {"); - - foreach (var className in nonStaticClassNames) - { - source.Append(" services.TryAddScoped<"); - source.Append(className); - source.Append(">();"); - source.AppendLine(); - } - - source.AppendLine(""" - } - } - """ - ); - - var sourceText = StringBuilderPool.ToStringAndReturn(source); - context.AddSource(AddEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); - } - - private static List GetDistinctNonStaticClassNames(ImmutableArray requestHandlers) - { - var classNames = new List(); - if (requestHandlers.IsDefaultOrEmpty) - return classNames; - - var seen = new HashSet(StringComparer.Ordinal); - for (var index = 0; index < requestHandlers.Length; index++) - { - var requestHandler = requestHandlers[index]; - if (requestHandler.Class.IsStatic) - continue; - - var className = requestHandler.Class.Name; - if (seen.Add(className)) - classNames.Add(className); - } - - return classNames; - } - - private static StringBuilder GetAddEndpointHandlersStringBuilder(List nonStaticClassNames) - { - var estimate = 512L; - for (var index = 0; index < nonStaticClassNames.Count; index++) - { - var className = nonStaticClassNames[index]; - estimate += 36 + className.Length; - } - - estimate += Math.Max(256, nonStaticClassNames.Count * 12); - estimate = (long)(estimate * 1.10); - - if (estimate < 512) - estimate = 512; - else if (estimate > int.MaxValue) - estimate = int.MaxValue; - - return StringBuilderPool.Get((int)estimate); - } - - private static void GenerateUseEndpointHandlersClass(SourceProductionContext context, ImmutableArray requestHandlers) - { - context.CancellationToken.ThrowIfCancellationRequested(); - - var source = GetUseEndpointHandlersStringBuilder(requestHandlers); - source.AppendLine(FileHeader); - - source.AppendLine(); - - source.AppendLine("using Microsoft.AspNetCore.Builder;"); - source.AppendLine("using Microsoft.AspNetCore.Http;"); - source.AppendLine("using Microsoft.AspNetCore.Mvc;"); - source.AppendLine("using Microsoft.AspNetCore.Routing;"); - if (HasRateLimitedHandlers(requestHandlers)) - source.AppendLine("using Microsoft.AspNetCore.RateLimiting;"); - source.AppendLine("using Microsoft.Extensions.DependencyInjection;"); - source.AppendLine(); - - source.Append("namespace "); - source.Append(RoutingNamespace); - source.AppendLine(";"); - - source.AppendLine(); - - source.Append("internal static class "); - source.Append(UseEndpointHandlersClassName); - source.AppendLine(); - - source.AppendLine("{"); - - source.Append(" internal static IEndpointRouteBuilder "); - source.Append(UseEndpointHandlersMethodName); - source.AppendLine("(this IEndpointRouteBuilder builder)"); - - source.AppendLine(" {"); - - var groupedClasses = GetClassesWithMapGroups(requestHandlers); - - for (var index = 0; index < groupedClasses.Count; index++) - { - var groupedClass = groupedClasses[index]; - source.Append(" var "); - source.Append(groupedClass.Configuration.GroupIdentifier); - source.Append(" = builder.MapGroup("); - source.Append(groupedClass.Configuration.GroupPattern!.ToStringLiteral()); - source.Append(')'); - AppendEndpointConfiguration(source, " ", groupedClass.Configuration, false); - source.AppendLine(";"); - } - - if (groupedClasses.Count > 0) - source.AppendLine(); - - for (var index = 0; index < requestHandlers.Length; index++) - { - if (index > 0) - source.AppendLine(); - - var requestHandler = requestHandlers[index]; - GenerateMapRequestHandler(source, requestHandler); - } - - source.AppendLine(""" - - return builder; - } - } - """ - ); - - var sourceText = StringBuilderPool.ToStringAndReturn(source); - context.AddSource(UseEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); - } - - private static bool HasRateLimitedHandlers(ImmutableArray requestHandlers) - { - for (var index = 0; index < requestHandlers.Length; index++) - { - var handler = requestHandlers[index]; - if (handler.Configuration.RequireRateLimiting) - return true; - } - - return false; - } - - private static List GetClassesWithMapGroups(ImmutableArray requestHandlers) - { - var groupedClasses = new List(); - if (requestHandlers.IsDefaultOrEmpty) - return groupedClasses; - - var seen = new HashSet(StringComparer.Ordinal); - for (var index = 0; index < requestHandlers.Length; index++) - { - var handler = requestHandlers[index]; - var handlerClass = handler.Class; - if (handlerClass.Configuration.GroupPattern is null) - continue; - - if (seen.Add(handlerClass.Name)) - groupedClasses.Add(handlerClass); - } - - return groupedClasses; - } - - private static void GenerateMapRequestHandler(StringBuilder source, RequestHandler requestHandler) - { - var wrapWithConfigure = requestHandler.Class.HasConfigureMethod; - var configureAcceptsServiceProvider = requestHandler.Class.ConfigureMethodAcceptsServiceProvider; - var indent = wrapWithConfigure ? " " : " "; - var continuationIndent = indent + " "; - var routeBuilderIdentifier = requestHandler.Class.Configuration.GroupIdentifier ?? "builder"; - - if (wrapWithConfigure) - { - source.Append(" "); - source.Append(requestHandler.Class.Name); - source.Append('.'); - source.Append(ConfigureMethodName); - source.AppendLine("("); - } - - var isFallback = string.Equals(requestHandler.HttpMethod, FallbackHttpMethod, StringComparison.Ordinal); - var mapMethodSuffix = isFallback ? null : GetMapMethodSuffix(requestHandler.HttpMethod); - - source.Append(indent); - if (isFallback) - { - source.Append(routeBuilderIdentifier); - source.Append(".MapFallback("); - if (!string.IsNullOrEmpty(requestHandler.Pattern)) - { - source.Append(requestHandler.Pattern.ToStringLiteral()); - source.Append(", "); - } - } - else - { - source.Append(routeBuilderIdentifier); - source.Append(".Map"); - source.Append(mapMethodSuffix ?? "Methods"); - source.Append('('); - source.Append(requestHandler.Pattern.ToStringLiteral()); - source.Append(", "); - if (mapMethodSuffix is null) - { - source.Append("new[] { \""); - source.Append(requestHandler.HttpMethod); - source.Append("\" }, "); - } - } - if (requestHandler.Method.IsStatic) - { - source.Append(requestHandler.Class.Name); - source.Append('.'); - source.Append(requestHandler.Method.Name); - } - else - { - source.Append("static ([FromServices] "); - source.Append(requestHandler.Class.Name); - source.Append(" handler"); - foreach (var parameter in requestHandler.Method.Parameters) - { - source.Append(", "); - source.Append(parameter.BindingPrefix); - source.Append(parameter.Type); - source.Append(' '); - source.Append(parameter.Name); - } - source.Append(") => handler."); - source.Append(requestHandler.Method.Name); - source.Append('('); - for (var index = 0; index < requestHandler.Method.Parameters.Count; index++) - { - if (index > 0) - source.Append(", "); - var parameter = requestHandler.Method.Parameters[index]; - source.Append(parameter.Name); - } - source.Append(')'); - } - source.Append(')'); - - var configuration = requestHandler.Configuration; - if (requestHandler.Class.Configuration.GroupPattern is null) - configuration = MergeEndpointConfigurations(requestHandler.Class.Configuration, configuration); - - AppendEndpointConfiguration(source, continuationIndent, configuration, true); - - if (wrapWithConfigure && configureAcceptsServiceProvider) - { - source.AppendLine(","); - source.Append(indent); - source.Append("builder.ServiceProvider"); - } - - if (wrapWithConfigure) - { - source.AppendLine(); - source.Append(" );"); - source.AppendLine(); - } - else - { - source.AppendLine(";"); - } - } - - private static void AppendEndpointConfiguration(StringBuilder source, string indent, EndpointConfiguration configuration, bool includeNameAndDisplayName) - { - if (includeNameAndDisplayName && !string.IsNullOrEmpty(configuration.Name)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithName("); - source.Append(configuration.Name.ToStringLiteral()); - source.Append(')'); - } - - if (includeNameAndDisplayName && !string.IsNullOrEmpty(configuration.DisplayName)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithDisplayName("); - source.Append(configuration.DisplayName.ToStringLiteral()); - source.Append(')'); - } - - if (!string.IsNullOrEmpty(configuration.Summary)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithSummary("); - source.Append(configuration.Summary.ToStringLiteral()); - source.Append(')'); - } - - if (!string.IsNullOrEmpty(configuration.Description)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithDescription("); - source.Append(configuration.Description.ToStringLiteral()); - source.Append(')'); - } - - if (!string.IsNullOrEmpty(configuration.GroupName)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithGroupName("); - source.Append(configuration.GroupName.ToStringLiteral()); - source.Append(')'); - } - - if (configuration.Order is { } orderValue) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithOrder("); - source.Append(orderValue); - source.Append(')'); - } - - if (configuration.ExcludeFromDescription) - { - source.AppendLine(); - source.Append(indent); - source.Append(".ExcludeFromDescription()"); - } - - if (configuration.Tags is { Count: > 0 }) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithTags("); - AppendCommaSeparatedLiterals(source, configuration.Tags.Value); - source.Append(')'); - } - - if (configuration.Accepts is { Count: > 0 }) - foreach (var accepts in configuration.Accepts.Value) - { - source.AppendLine(); - source.Append(indent); - source.Append(".Accepts<"); - source.Append(accepts.RequestType); - source.Append('>'); - source.Append('('); - if (accepts.IsOptional) - source.Append("isOptional: true, "); - source.Append(accepts.ContentType.ToStringLiteral()); - AppendAdditionalContentTypes(source, accepts.AdditionalContentTypes); - source.Append(')'); - } - - if (configuration.Produces is { Count: > 0 }) - foreach (var produces in configuration.Produces.Value) - { - source.AppendLine(); - source.Append(indent); - source.Append(".Produces<"); - source.Append(produces.ResponseType); - source.Append('>'); - source.Append('('); - source.Append(produces.StatusCode); - AppendOptionalContentTypes(source, produces.ContentType, produces.AdditionalContentTypes); - source.Append(')'); - } - - if (configuration.ProducesProblem is { Count: > 0 }) - foreach (var producesProblem in configuration.ProducesProblem.Value) - { - source.AppendLine(); - source.Append(indent); - source.Append(".ProducesProblem("); - source.Append(producesProblem.StatusCode); - AppendOptionalContentTypes(source, producesProblem.ContentType, producesProblem.AdditionalContentTypes); - source.Append(')'); - } - - if (configuration.ProducesValidationProblem is { Count: > 0 }) - foreach (var producesValidationProblem in configuration.ProducesValidationProblem.Value) - { - source.AppendLine(); - source.Append(indent); - source.Append(".ProducesValidationProblem("); - source.Append(producesValidationProblem.StatusCode); - AppendOptionalContentTypes(source, producesValidationProblem.ContentType, producesValidationProblem.AdditionalContentTypes); - source.Append(')'); - } - - if (configuration.RequireAuthorization) - { - source.AppendLine(); - if (configuration.AuthorizationPolicies is { Count: > 0 }) - { - source.Append(indent); - source.Append(".RequireAuthorization("); - AppendCommaSeparatedLiterals(source, configuration.AuthorizationPolicies.Value); - source.Append(')'); - } - else - { - source.Append(indent); - source.Append(".RequireAuthorization()"); - } - } - - if (configuration.RequireCors) - { - source.AppendLine(); - source.Append(indent); - if (!string.IsNullOrEmpty(configuration.CorsPolicyName)) - { - source.Append(".RequireCors("); - source.Append(configuration.CorsPolicyName.ToStringLiteral()); - source.Append(')'); - } - else - { - source.Append(".RequireCors()"); - } - } - - if (configuration.RequiredHosts is { Count: > 0 }) - { - source.AppendLine(); - source.Append(indent); - source.Append(".RequireHost("); - AppendCommaSeparatedLiterals(source, configuration.RequiredHosts.Value); - source.Append(')'); - } - - if (configuration.RequireRateLimiting && !string.IsNullOrEmpty(configuration.RateLimitingPolicyName)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".RequireRateLimiting("); - source.Append(configuration.RateLimitingPolicyName.ToStringLiteral()); - source.Append(')'); - } - - if (configuration.DisableAntiforgery) - { - source.AppendLine(); - source.Append(indent); - source.Append(".DisableAntiforgery()"); - } - - if (configuration.AllowAnonymous) - { - source.AppendLine(); - source.Append(indent); - source.Append(".AllowAnonymous()"); - } - - if (configuration.ShortCircuit) - { - source.AppendLine(); - source.Append(indent); - source.Append(".ShortCircuit()"); - } - - if (configuration.DisableValidation) - { - source.AppendLine(); - source.Append(indent); - source.Append(".DisableValidation()"); - source.AppendLine(); - } - - if (configuration.DisableRequestTimeout) - { - source.AppendLine(); - source.Append(indent); - source.Append(".DisableRequestTimeout()"); - } - else if (configuration.WithRequestTimeout) - { - source.AppendLine(); - source.Append(indent); - if (!string.IsNullOrEmpty(configuration.RequestTimeoutPolicyName)) - { - source.Append(".WithRequestTimeout("); - source.Append(configuration.RequestTimeoutPolicyName.ToStringLiteral()); - source.Append(')'); - } - else - { - source.Append(".WithRequestTimeout()"); - } - } - - if (configuration.EndpointFilterTypes is { Count: > 0 }) - foreach (var filterType in configuration.EndpointFilterTypes.Value) - { - source.AppendLine(); - source.Append(indent); - source.Append(".AddEndpointFilter<"); - source.Append(filterType); - source.Append(">()"); - } - } - - private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfiguration classConfiguration, EndpointConfiguration methodConfiguration) - { - var name = methodConfiguration.Name ?? classConfiguration.Name; - var displayName = methodConfiguration.DisplayName ?? classConfiguration.DisplayName; - var summary = methodConfiguration.Summary ?? classConfiguration.Summary; - var description = methodConfiguration.Description ?? classConfiguration.Description; - var tags = MergeDistinctStrings(classConfiguration.Tags, methodConfiguration.Tags); - var accepts = ConcatEquatable(classConfiguration.Accepts, methodConfiguration.Accepts); - var produces = ConcatEquatable(classConfiguration.Produces, methodConfiguration.Produces); - var producesProblem = ConcatEquatable(classConfiguration.ProducesProblem, methodConfiguration.ProducesProblem); - var producesValidationProblem = ConcatEquatable(classConfiguration.ProducesValidationProblem, methodConfiguration.ProducesValidationProblem); - var excludeFromDescription = classConfiguration.ExcludeFromDescription || methodConfiguration.ExcludeFromDescription; - var authorizationPolicies = MergeDistinctStrings(classConfiguration.AuthorizationPolicies, methodConfiguration.AuthorizationPolicies); - var requiredHosts = MergeDistinctStrings(classConfiguration.RequiredHosts, methodConfiguration.RequiredHosts); - var endpointFilterTypes = ConcatEquatable(classConfiguration.EndpointFilterTypes, methodConfiguration.EndpointFilterTypes); - var requireAuthorization = classConfiguration.RequireAuthorization || methodConfiguration.RequireAuthorization; - var disableAntiforgery = classConfiguration.DisableAntiforgery || methodConfiguration.DisableAntiforgery; - var allowAnonymous = classConfiguration.AllowAnonymous || methodConfiguration.AllowAnonymous; - var requireCors = classConfiguration.RequireCors || methodConfiguration.RequireCors; - var corsPolicyName = methodConfiguration.CorsPolicyName ?? classConfiguration.CorsPolicyName; - var requireRateLimiting = classConfiguration.RequireRateLimiting || methodConfiguration.RequireRateLimiting; - var rateLimitingPolicyName = methodConfiguration.RateLimitingPolicyName ?? classConfiguration.RateLimitingPolicyName; - var shortCircuit = classConfiguration.ShortCircuit || methodConfiguration.ShortCircuit; - var disableValidation = classConfiguration.DisableValidation || methodConfiguration.DisableValidation; - var disableRequestTimeout = classConfiguration.DisableRequestTimeout || methodConfiguration.DisableRequestTimeout; - var withRequestTimeout = classConfiguration.WithRequestTimeout || methodConfiguration.WithRequestTimeout; - var groupIdentifier = classConfiguration.GroupIdentifier ?? methodConfiguration.GroupIdentifier; - var groupPattern = classConfiguration.GroupPattern ?? methodConfiguration.GroupPattern; - var groupName = classConfiguration.GroupName ?? methodConfiguration.GroupName; - - string? requestTimeoutPolicyName = null; - if (methodConfiguration.WithRequestTimeout) - requestTimeoutPolicyName = methodConfiguration.RequestTimeoutPolicyName; - else if (classConfiguration.WithRequestTimeout) - requestTimeoutPolicyName = classConfiguration.RequestTimeoutPolicyName; - - if (disableRequestTimeout) - { - withRequestTimeout = false; - requestTimeoutPolicyName = null; - } - - var order = methodConfiguration.Order ?? classConfiguration.Order; - - return new EndpointConfiguration - { - Name = name, - DisplayName = displayName, - Summary = summary, - Description = description, - Tags = tags, - Accepts = accepts, - Produces = produces, - ProducesProblem = producesProblem, - ProducesValidationProblem = producesValidationProblem, - ExcludeFromDescription = excludeFromDescription, - RequireAuthorization = requireAuthorization, - AuthorizationPolicies = authorizationPolicies, - DisableAntiforgery = disableAntiforgery, - AllowAnonymous = allowAnonymous, - RequireCors = requireCors, - CorsPolicyName = corsPolicyName, - RequiredHosts = requiredHosts, - RequireRateLimiting = requireRateLimiting, - RateLimitingPolicyName = rateLimitingPolicyName, - EndpointFilterTypes = endpointFilterTypes, - ShortCircuit = shortCircuit, - DisableValidation = disableValidation, - DisableRequestTimeout = disableRequestTimeout, - WithRequestTimeout = withRequestTimeout, - RequestTimeoutPolicyName = requestTimeoutPolicyName, - Order = order, - GroupIdentifier = groupIdentifier, - GroupPattern = groupPattern, - GroupName = groupName, - }; - } - - private static EquatableImmutableArray? MergeDistinctStrings(EquatableImmutableArray? first, EquatableImmutableArray? second) - { - if (first is not { Count: > 0 }) - return second; - if (second is not { Count: > 0 }) - return first; - - var merged = MergeUnion(first, second.Value); - return merged.Count > 0 ? merged : null; - } - - private static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, EquatableImmutableArray values) - { - List? list = null; - HashSet? seen = null; - - if (existing is { Count: > 0 }) - { - var count = existing.Value.Count; - list = new List(count + 4); - list.AddRange(existing.Value); - seen = new HashSet(existing.Value, StringComparer.OrdinalIgnoreCase); - } - - foreach (var value in values) - { - var normalized = value.NormalizeOptionalString(); - if (normalized is not { Length: > 0 }) - continue; - - seen ??= new HashSet(StringComparer.OrdinalIgnoreCase); - if (!seen.Add(normalized)) - continue; - - list ??= []; - list.Add(normalized); - } - - return list?.ToEquatableImmutableArray() ?? EquatableImmutableArray.Empty; - } - - private static EquatableImmutableArray? ConcatEquatable(EquatableImmutableArray? first, EquatableImmutableArray? second) - { - if (first is not { Count: > 0 }) - return second; - if (second is not { Count: > 0 }) - return first; - - var builder = ImmutableArray.CreateBuilder(first.Value.Count + second.Value.Count); - builder.AddRange(first.Value); - builder.AddRange(second.Value); - return builder.ToEquatableImmutableArray(); - } - - private static string? GetMapMethodSuffix(string httpMethod) - { - return httpMethod switch - { - "GET" => "Get", - "POST" => "Post", - "PUT" => "Put", - "DELETE" => "Delete", - "PATCH" => "Patch", - _ => null, - }; - } - - private static StringBuilder GetUseEndpointHandlersStringBuilder(ImmutableArray requestHandlers) - { - const int baseSize = 4096; - const int perHandler = 512; - - var handlerCount = Math.Max(requestHandlers.Length, 0); - var estimate = baseSize + (long)perHandler * handlerCount; - estimate = (long)(estimate * 1.10); - - if (estimate > int.MaxValue) - estimate = int.MaxValue; - - return StringBuilderPool.Get((int)Math.Max(baseSize, estimate)); - } - - private static void AppendAdditionalContentTypes(StringBuilder source, EquatableImmutableArray? additionalContentTypes) - { - if (additionalContentTypes is not { Count: > 0 }) - return; - - foreach (var additional in additionalContentTypes.Value) - { - source.Append(", "); - source.Append(additional.ToStringLiteral()); - } - } - - private static void AppendCommaSeparatedLiterals(StringBuilder source, EquatableImmutableArray values) - { - if (values.Count == 0) - return; - - source.Append(values[0] - .ToStringLiteral() - ); - for (var i = 1; i < values.Count; i++) - { - source.Append(", "); - source.Append(values[i] - .ToStringLiteral() - ); - } - } - - private static void AppendOptionalContentTypes(StringBuilder source, string? contentType, EquatableImmutableArray? additionalContentTypes) - { - if (string.IsNullOrEmpty(contentType) && additionalContentTypes is not { Count: > 0 }) - return; - - source.Append(", "); - source.Append(contentType is { Length: > 0 } ? contentType.ToStringLiteral() : "null"); - AppendAdditionalContentTypes(source, additionalContentTypes); - } } diff --git a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs new file mode 100644 index 0000000..c1c0715 --- /dev/null +++ b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs @@ -0,0 +1,662 @@ +using System.Collections.Immutable; +using System.Text; +using GeneratedEndpoints.Common; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; +using static GeneratedEndpoints.Common.Constants; + +namespace GeneratedEndpoints; + +// ReSharper disable ForCanBeConvertedToForeach +// ReSharper disable LoopCanBeConvertedToQuery +// Do not refactor, use for loop to avoid allocations. + +internal static class UseEndpointHandlersGenerator +{ + public static void GenerateSource(SourceProductionContext context, ImmutableArray requestHandlers) + { + context.CancellationToken.ThrowIfCancellationRequested(); + + var source = GetUseEndpointHandlersStringBuilder(requestHandlers); + source.AppendLine(FileHeader); + + source.AppendLine(); + + source.AppendLine("using Microsoft.AspNetCore.Builder;"); + source.AppendLine("using Microsoft.AspNetCore.Http;"); + source.AppendLine("using Microsoft.AspNetCore.Mvc;"); + source.AppendLine("using Microsoft.AspNetCore.Routing;"); + if (HasRateLimitedHandlers(requestHandlers)) + source.AppendLine("using Microsoft.AspNetCore.RateLimiting;"); + source.AppendLine("using Microsoft.Extensions.DependencyInjection;"); + source.AppendLine(); + + source.Append("namespace "); + source.Append(RoutingNamespace); + source.AppendLine(";"); + + source.AppendLine(); + + source.Append("internal static class "); + source.Append(UseEndpointHandlersClassName); + source.AppendLine(); + + source.AppendLine("{"); + + source.Append(" internal static IEndpointRouteBuilder "); + source.Append(UseEndpointHandlersMethodName); + source.AppendLine("(this IEndpointRouteBuilder builder)"); + + source.AppendLine(" {"); + + var groupedClasses = GetClassesWithMapGroups(requestHandlers); + + for (var index = 0; index < groupedClasses.Count; index++) + { + var groupedClass = groupedClasses[index]; + source.Append(" var "); + source.Append(groupedClass.Configuration.GroupIdentifier); + source.Append(" = builder.MapGroup("); + source.Append(groupedClass.Configuration.GroupPattern!.ToStringLiteral()); + source.Append(')'); + AppendEndpointConfiguration(source, " ", groupedClass.Configuration, false); + source.AppendLine(";"); + } + + if (groupedClasses.Count > 0) + source.AppendLine(); + + for (var index = 0; index < requestHandlers.Length; index++) + { + if (index > 0) + source.AppendLine(); + + var requestHandler = requestHandlers[index]; + GenerateMapRequestHandler(source, requestHandler); + } + + source.AppendLine(""" + + return builder; + } + } + """ + ); + + var sourceText = StringBuilderPool.ToStringAndReturn(source); + context.AddSource(UseEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); + } + + private static bool HasRateLimitedHandlers(ImmutableArray requestHandlers) + { + for (var index = 0; index < requestHandlers.Length; index++) + { + var handler = requestHandlers[index]; + if (handler.Configuration.RequireRateLimiting) + return true; + } + + return false; + } + + private static List GetClassesWithMapGroups(ImmutableArray requestHandlers) + { + var groupedClasses = new List(); + if (requestHandlers.IsDefaultOrEmpty) + return groupedClasses; + + var seen = new HashSet(StringComparer.Ordinal); + for (var index = 0; index < requestHandlers.Length; index++) + { + var handler = requestHandlers[index]; + var handlerClass = handler.Class; + if (handlerClass.Configuration.GroupPattern is null) + continue; + + if (seen.Add(handlerClass.Name)) + groupedClasses.Add(handlerClass); + } + + return groupedClasses; + } + + private static void GenerateMapRequestHandler(StringBuilder source, RequestHandler requestHandler) + { + var wrapWithConfigure = requestHandler.Class.HasConfigureMethod; + var configureAcceptsServiceProvider = requestHandler.Class.ConfigureMethodAcceptsServiceProvider; + var indent = wrapWithConfigure ? " " : " "; + var continuationIndent = indent + " "; + var routeBuilderIdentifier = requestHandler.Class.Configuration.GroupIdentifier ?? "builder"; + + if (wrapWithConfigure) + { + source.Append(" "); + source.Append(requestHandler.Class.Name); + source.Append('.'); + source.Append(ConfigureMethodName); + source.AppendLine("("); + } + + var isFallback = string.Equals(requestHandler.HttpMethod, FallbackHttpMethod, StringComparison.Ordinal); + var mapMethodSuffix = isFallback ? null : GetMapMethodSuffix(requestHandler.HttpMethod); + + source.Append(indent); + if (isFallback) + { + source.Append(routeBuilderIdentifier); + source.Append(".MapFallback("); + if (!string.IsNullOrEmpty(requestHandler.Pattern)) + { + source.Append(requestHandler.Pattern.ToStringLiteral()); + source.Append(", "); + } + } + else + { + source.Append(routeBuilderIdentifier); + source.Append(".Map"); + source.Append(mapMethodSuffix ?? "Methods"); + source.Append('('); + source.Append(requestHandler.Pattern.ToStringLiteral()); + source.Append(", "); + if (mapMethodSuffix is null) + { + source.Append("new[] { \""); + source.Append(requestHandler.HttpMethod); + source.Append("\" }, "); + } + } + if (requestHandler.Method.IsStatic) + { + source.Append(requestHandler.Class.Name); + source.Append('.'); + source.Append(requestHandler.Method.Name); + } + else + { + source.Append("static ([FromServices] "); + source.Append(requestHandler.Class.Name); + source.Append(" handler"); + foreach (var parameter in requestHandler.Method.Parameters) + { + source.Append(", "); + source.Append(parameter.BindingPrefix); + source.Append(parameter.Type); + source.Append(' '); + source.Append(parameter.Name); + } + source.Append(") => handler."); + source.Append(requestHandler.Method.Name); + source.Append('('); + for (var index = 0; index < requestHandler.Method.Parameters.Count; index++) + { + if (index > 0) + source.Append(", "); + var parameter = requestHandler.Method.Parameters[index]; + source.Append(parameter.Name); + } + source.Append(')'); + } + source.Append(')'); + + var configuration = requestHandler.Configuration; + if (requestHandler.Class.Configuration.GroupPattern is null) + configuration = MergeEndpointConfigurations(requestHandler.Class.Configuration, configuration); + + AppendEndpointConfiguration(source, continuationIndent, configuration, true); + + if (wrapWithConfigure && configureAcceptsServiceProvider) + { + source.AppendLine(","); + source.Append(indent); + source.Append("builder.ServiceProvider"); + } + + if (wrapWithConfigure) + { + source.AppendLine(); + source.Append(" );"); + source.AppendLine(); + } + else + { + source.AppendLine(";"); + } + } + + private static void AppendEndpointConfiguration(StringBuilder source, string indent, EndpointConfiguration configuration, bool includeNameAndDisplayName) + { + if (includeNameAndDisplayName && !string.IsNullOrEmpty(configuration.Name)) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithName("); + source.Append(configuration.Name.ToStringLiteral()); + source.Append(')'); + } + + if (includeNameAndDisplayName && !string.IsNullOrEmpty(configuration.DisplayName)) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithDisplayName("); + source.Append(configuration.DisplayName.ToStringLiteral()); + source.Append(')'); + } + + if (!string.IsNullOrEmpty(configuration.Summary)) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithSummary("); + source.Append(configuration.Summary.ToStringLiteral()); + source.Append(')'); + } + + if (!string.IsNullOrEmpty(configuration.Description)) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithDescription("); + source.Append(configuration.Description.ToStringLiteral()); + source.Append(')'); + } + + if (!string.IsNullOrEmpty(configuration.GroupName)) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithGroupName("); + source.Append(configuration.GroupName.ToStringLiteral()); + source.Append(')'); + } + + if (configuration.Order is { } orderValue) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithOrder("); + source.Append(orderValue); + source.Append(')'); + } + + if (configuration.ExcludeFromDescription) + { + source.AppendLine(); + source.Append(indent); + source.Append(".ExcludeFromDescription()"); + } + + if (configuration.Tags is { Count: > 0 }) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithTags("); + AppendCommaSeparatedLiterals(source, configuration.Tags.Value); + source.Append(')'); + } + + if (configuration.Accepts is { Count: > 0 }) + foreach (var accepts in configuration.Accepts.Value) + { + source.AppendLine(); + source.Append(indent); + source.Append(".Accepts<"); + source.Append(accepts.RequestType); + source.Append('>'); + source.Append('('); + if (accepts.IsOptional) + source.Append("isOptional: true, "); + source.Append(accepts.ContentType.ToStringLiteral()); + AppendAdditionalContentTypes(source, accepts.AdditionalContentTypes); + source.Append(')'); + } + + if (configuration.Produces is { Count: > 0 }) + foreach (var produces in configuration.Produces.Value) + { + source.AppendLine(); + source.Append(indent); + source.Append(".Produces<"); + source.Append(produces.ResponseType); + source.Append('>'); + source.Append('('); + source.Append(produces.StatusCode); + AppendOptionalContentTypes(source, produces.ContentType, produces.AdditionalContentTypes); + source.Append(')'); + } + + if (configuration.ProducesProblem is { Count: > 0 }) + foreach (var producesProblem in configuration.ProducesProblem.Value) + { + source.AppendLine(); + source.Append(indent); + source.Append(".ProducesProblem("); + source.Append(producesProblem.StatusCode); + AppendOptionalContentTypes(source, producesProblem.ContentType, producesProblem.AdditionalContentTypes); + source.Append(')'); + } + + if (configuration.ProducesValidationProblem is { Count: > 0 }) + foreach (var producesValidationProblem in configuration.ProducesValidationProblem.Value) + { + source.AppendLine(); + source.Append(indent); + source.Append(".ProducesValidationProblem("); + source.Append(producesValidationProblem.StatusCode); + AppendOptionalContentTypes(source, producesValidationProblem.ContentType, producesValidationProblem.AdditionalContentTypes); + source.Append(')'); + } + + if (configuration.RequireAuthorization) + { + source.AppendLine(); + if (configuration.AuthorizationPolicies is { Count: > 0 }) + { + source.Append(indent); + source.Append(".RequireAuthorization("); + AppendCommaSeparatedLiterals(source, configuration.AuthorizationPolicies.Value); + source.Append(')'); + } + else + { + source.Append(indent); + source.Append(".RequireAuthorization()"); + } + } + + if (configuration.RequireCors) + { + source.AppendLine(); + source.Append(indent); + if (!string.IsNullOrEmpty(configuration.CorsPolicyName)) + { + source.Append(".RequireCors("); + source.Append(configuration.CorsPolicyName.ToStringLiteral()); + source.Append(')'); + } + else + { + source.Append(".RequireCors()"); + } + } + + if (configuration.RequiredHosts is { Count: > 0 }) + { + source.AppendLine(); + source.Append(indent); + source.Append(".RequireHost("); + AppendCommaSeparatedLiterals(source, configuration.RequiredHosts.Value); + source.Append(')'); + } + + if (configuration.RequireRateLimiting && !string.IsNullOrEmpty(configuration.RateLimitingPolicyName)) + { + source.AppendLine(); + source.Append(indent); + source.Append(".RequireRateLimiting("); + source.Append(configuration.RateLimitingPolicyName.ToStringLiteral()); + source.Append(')'); + } + + if (configuration.DisableAntiforgery) + { + source.AppendLine(); + source.Append(indent); + source.Append(".DisableAntiforgery()"); + } + + if (configuration.AllowAnonymous) + { + source.AppendLine(); + source.Append(indent); + source.Append(".AllowAnonymous()"); + } + + if (configuration.ShortCircuit) + { + source.AppendLine(); + source.Append(indent); + source.Append(".ShortCircuit()"); + } + + if (configuration.DisableValidation) + { + source.AppendLine(); + source.Append(indent); + source.Append(".DisableValidation()"); + source.AppendLine(); + } + + if (configuration.DisableRequestTimeout) + { + source.AppendLine(); + source.Append(indent); + source.Append(".DisableRequestTimeout()"); + } + else if (configuration.WithRequestTimeout) + { + source.AppendLine(); + source.Append(indent); + if (!string.IsNullOrEmpty(configuration.RequestTimeoutPolicyName)) + { + source.Append(".WithRequestTimeout("); + source.Append(configuration.RequestTimeoutPolicyName.ToStringLiteral()); + source.Append(')'); + } + else + { + source.Append(".WithRequestTimeout()"); + } + } + + if (configuration.EndpointFilterTypes is { Count: > 0 }) + foreach (var filterType in configuration.EndpointFilterTypes.Value) + { + source.AppendLine(); + source.Append(indent); + source.Append(".AddEndpointFilter<"); + source.Append(filterType); + source.Append(">()"); + } + } + + private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfiguration classConfiguration, EndpointConfiguration methodConfiguration) + { + var name = methodConfiguration.Name ?? classConfiguration.Name; + var displayName = methodConfiguration.DisplayName ?? classConfiguration.DisplayName; + var summary = methodConfiguration.Summary ?? classConfiguration.Summary; + var description = methodConfiguration.Description ?? classConfiguration.Description; + var tags = MergeDistinctStrings(classConfiguration.Tags, methodConfiguration.Tags); + var accepts = ConcatEquatable(classConfiguration.Accepts, methodConfiguration.Accepts); + var produces = ConcatEquatable(classConfiguration.Produces, methodConfiguration.Produces); + var producesProblem = ConcatEquatable(classConfiguration.ProducesProblem, methodConfiguration.ProducesProblem); + var producesValidationProblem = ConcatEquatable(classConfiguration.ProducesValidationProblem, methodConfiguration.ProducesValidationProblem); + var excludeFromDescription = classConfiguration.ExcludeFromDescription || methodConfiguration.ExcludeFromDescription; + var authorizationPolicies = MergeDistinctStrings(classConfiguration.AuthorizationPolicies, methodConfiguration.AuthorizationPolicies); + var requiredHosts = MergeDistinctStrings(classConfiguration.RequiredHosts, methodConfiguration.RequiredHosts); + var endpointFilterTypes = ConcatEquatable(classConfiguration.EndpointFilterTypes, methodConfiguration.EndpointFilterTypes); + var requireAuthorization = classConfiguration.RequireAuthorization || methodConfiguration.RequireAuthorization; + var disableAntiforgery = classConfiguration.DisableAntiforgery || methodConfiguration.DisableAntiforgery; + var allowAnonymous = classConfiguration.AllowAnonymous || methodConfiguration.AllowAnonymous; + var requireCors = classConfiguration.RequireCors || methodConfiguration.RequireCors; + var corsPolicyName = methodConfiguration.CorsPolicyName ?? classConfiguration.CorsPolicyName; + var requireRateLimiting = classConfiguration.RequireRateLimiting || methodConfiguration.RequireRateLimiting; + var rateLimitingPolicyName = methodConfiguration.RateLimitingPolicyName ?? classConfiguration.RateLimitingPolicyName; + var shortCircuit = classConfiguration.ShortCircuit || methodConfiguration.ShortCircuit; + var disableValidation = classConfiguration.DisableValidation || methodConfiguration.DisableValidation; + var disableRequestTimeout = classConfiguration.DisableRequestTimeout || methodConfiguration.DisableRequestTimeout; + var withRequestTimeout = classConfiguration.WithRequestTimeout || methodConfiguration.WithRequestTimeout; + var groupIdentifier = classConfiguration.GroupIdentifier ?? methodConfiguration.GroupIdentifier; + var groupPattern = classConfiguration.GroupPattern ?? methodConfiguration.GroupPattern; + var groupName = classConfiguration.GroupName ?? methodConfiguration.GroupName; + + string? requestTimeoutPolicyName = null; + if (methodConfiguration.WithRequestTimeout) + requestTimeoutPolicyName = methodConfiguration.RequestTimeoutPolicyName; + else if (classConfiguration.WithRequestTimeout) + requestTimeoutPolicyName = classConfiguration.RequestTimeoutPolicyName; + + if (disableRequestTimeout) + { + withRequestTimeout = false; + requestTimeoutPolicyName = null; + } + + var order = methodConfiguration.Order ?? classConfiguration.Order; + + return new EndpointConfiguration + { + Name = name, + DisplayName = displayName, + Summary = summary, + Description = description, + Tags = tags, + Accepts = accepts, + Produces = produces, + ProducesProblem = producesProblem, + ProducesValidationProblem = producesValidationProblem, + ExcludeFromDescription = excludeFromDescription, + RequireAuthorization = requireAuthorization, + AuthorizationPolicies = authorizationPolicies, + DisableAntiforgery = disableAntiforgery, + AllowAnonymous = allowAnonymous, + RequireCors = requireCors, + CorsPolicyName = corsPolicyName, + RequiredHosts = requiredHosts, + RequireRateLimiting = requireRateLimiting, + RateLimitingPolicyName = rateLimitingPolicyName, + EndpointFilterTypes = endpointFilterTypes, + ShortCircuit = shortCircuit, + DisableValidation = disableValidation, + DisableRequestTimeout = disableRequestTimeout, + WithRequestTimeout = withRequestTimeout, + RequestTimeoutPolicyName = requestTimeoutPolicyName, + Order = order, + GroupIdentifier = groupIdentifier, + GroupPattern = groupPattern, + GroupName = groupName, + }; + } + + private static EquatableImmutableArray? MergeDistinctStrings(EquatableImmutableArray? first, EquatableImmutableArray? second) + { + if (first is not { Count: > 0 }) + return second; + if (second is not { Count: > 0 }) + return first; + + var merged = MergeUnion(first, second.Value); + return merged.Count > 0 ? merged : null; + } + + private static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, EquatableImmutableArray values) + { + List? list = null; + HashSet? seen = null; + + if (existing is { Count: > 0 }) + { + var count = existing.Value.Count; + list = new List(count + 4); + list.AddRange(existing.Value); + seen = new HashSet(existing.Value, StringComparer.OrdinalIgnoreCase); + } + + foreach (var value in values) + { + var normalized = value.NormalizeOptionalString(); + if (normalized is not { Length: > 0 }) + continue; + + seen ??= new HashSet(StringComparer.OrdinalIgnoreCase); + if (!seen.Add(normalized)) + continue; + + list ??= []; + list.Add(normalized); + } + + return list?.ToEquatableImmutableArray() ?? EquatableImmutableArray.Empty; + } + + private static EquatableImmutableArray? ConcatEquatable(EquatableImmutableArray? first, EquatableImmutableArray? second) + { + if (first is not { Count: > 0 }) + return second; + if (second is not { Count: > 0 }) + return first; + + var builder = ImmutableArray.CreateBuilder(first.Value.Count + second.Value.Count); + builder.AddRange(first.Value); + builder.AddRange(second.Value); + return builder.ToEquatableImmutableArray(); + } + + private static string? GetMapMethodSuffix(string httpMethod) + { + return httpMethod switch + { + "GET" => "Get", + "POST" => "Post", + "PUT" => "Put", + "DELETE" => "Delete", + "PATCH" => "Patch", + _ => null, + }; + } + + private static StringBuilder GetUseEndpointHandlersStringBuilder(ImmutableArray requestHandlers) + { + const int baseSize = 4096; + const int perHandler = 512; + + var handlerCount = Math.Max(requestHandlers.Length, 0); + var estimate = baseSize + (long)perHandler * handlerCount; + estimate = (long)(estimate * 1.10); + + if (estimate > int.MaxValue) + estimate = int.MaxValue; + + return StringBuilderPool.Get((int)Math.Max(baseSize, estimate)); + } + + private static void AppendAdditionalContentTypes(StringBuilder source, EquatableImmutableArray? additionalContentTypes) + { + if (additionalContentTypes is not { Count: > 0 }) + return; + + foreach (var additional in additionalContentTypes.Value) + { + source.Append(", "); + source.Append(additional.ToStringLiteral()); + } + } + + private static void AppendCommaSeparatedLiterals(StringBuilder source, EquatableImmutableArray values) + { + if (values.Count == 0) + return; + + source.Append(values[0] + .ToStringLiteral() + ); + for (var i = 1; i < values.Count; i++) + { + source.Append(", "); + source.Append(values[i] + .ToStringLiteral() + ); + } + } + + private static void AppendOptionalContentTypes(StringBuilder source, string? contentType, EquatableImmutableArray? additionalContentTypes) + { + if (string.IsNullOrEmpty(contentType) && additionalContentTypes is not { Count: > 0 }) + return; + + source.Append(", "); + source.Append(contentType is { Length: > 0 } ? contentType.ToStringLiteral() : "null"); + AppendAdditionalContentTypes(source, additionalContentTypes); + } +} From f656ab8062ab07a415c06bd76b2335ba3b017a99 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 23:04:11 -0500 Subject: [PATCH 14/32] Optimize MinimalApiGenerator hot paths (#63) --- src/GeneratedEndpoints/MinimalApiGenerator.cs | 133 +++++++++++++----- 1 file changed, 100 insertions(+), 33 deletions(-) diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 608f931..fe72434 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -116,24 +116,28 @@ private static (string HttpMethod, string Pattern, string? Name) GetRequestHandl { cancellationToken.ThrowIfCancellationRequested(); - var attributeName = attribute.AttributeClass?.Name ?? ""; + var attributeName = attribute.AttributeClass?.Name ?? string.Empty; - var httpMethod = HttpAttributeDefinitionsByName.TryGetValue(attributeName, out var definition) ? definition.Verb : ""; + var httpMethod = HttpAttributeDefinitionsByName.TryGetValue(attributeName, out var definition) ? definition.Verb : string.Empty; - var pattern = (attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : "") ?? ""; + var pattern = attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : null; + pattern ??= string.Empty; string? name = null; - for (var index = 0; index < attribute.NamedArguments.Length; index++) + var namedArguments = attribute.NamedArguments; + if (!namedArguments.IsDefaultOrEmpty) { - var namedArg = attribute.NamedArguments[index]; - switch (namedArg.Key) + for (var i = 0; i < namedArguments.Length; i++) { - case NameAttributeNamedParameter: - { - var value = namedArg.Value.Value as string; - name = string.IsNullOrWhiteSpace(value) ? null : value!.Trim(); + var namedArg = namedArguments[i]; + if (namedArg.Key != NameAttributeNamedParameter) + continue; + + if (namedArg.Value.Value is not string value) break; - } + + name = string.IsNullOrWhiteSpace(value) ? null : value.Trim(); + break; } } @@ -193,12 +197,16 @@ private static ImmutableArray NormalizeRequestHandlers(Equatable private static ImmutableArray SortRequestHandlers(EquatableImmutableArray requestHandlers) { - if (requestHandlers.Count <= 1) - return [..requestHandlers]; - - var array = requestHandlers.ToArray(); - Array.Sort(array, RequestHandlerComparer.Instance); - return [..array]; + var count = requestHandlers.Count; + if (count == 0) + return ImmutableArray.Empty; + if (count == 1) + return ImmutableArray.Create(requestHandlers[0]); + + var builder = ImmutableArray.CreateBuilder(count); + builder.AddRange(requestHandlers); + builder.Sort(RequestHandlerComparer.Instance); + return builder.MoveToImmutable(); } private static ImmutableArray EnsureUniqueEndpointNames(ImmutableArray requestHandlers) @@ -208,8 +216,9 @@ private static ImmutableArray EnsureUniqueEndpointNames(Immutabl return requestHandlers; var builder = requestHandlers.ToBuilder(); - foreach (var index in collidingHandlers) + for (var i = 0; i < collidingHandlers.Length; i++) { + var index = collidingHandlers[i]; var handler = builder[index]; var configuration = handler.Configuration with { @@ -230,10 +239,11 @@ private static ImmutableArray GetRequestHandlersWithNameCollisions(Immutabl return ImmutableArray.Empty; var handlerCount = requestHandlers.Length; - var nameToFirstIndex = new Dictionary<(string Name, string Method), int>(handlerCount); + var nameToFirstIndex = new Dictionary(handlerCount); var collisionFlags = ArrayPool.Shared.Rent(handlerCount); Array.Clear(collisionFlags, 0, handlerCount); - List? collidingIndices = null; + int[]? collidingIndices = null; + var collidingCount = 0; try { @@ -244,7 +254,7 @@ private static ImmutableArray GetRequestHandlersWithNameCollisions(Immutabl if (string.IsNullOrEmpty(name)) continue; - var key = (name!, handler.Method.Name); + var key = new HandlerNameKey(name!, handler.Method.Name); if (nameToFirstIndex.TryGetValue(key, out var firstIndex)) { @@ -257,17 +267,20 @@ private static ImmutableArray GetRequestHandlersWithNameCollisions(Immutabl } } - if (collidingIndices is null || collidingIndices.Count == 0) + if (collidingIndices is null || collidingCount == 0) return ImmutableArray.Empty; - collidingIndices.Sort(); - var builder = ImmutableArray.CreateBuilder(collidingIndices.Count); - builder.AddRange(collidingIndices); + Array.Sort(collidingIndices, 0, collidingCount); + var builder = ImmutableArray.CreateBuilder(collidingCount); + for (var i = 0; i < collidingCount; i++) + builder.Add(collidingIndices[i]); return builder.MoveToImmutable(); } finally { ArrayPool.Shared.Return(collisionFlags); + if (collidingIndices is not null) + ArrayPool.Shared.Return(collidingIndices); } void MarkCollision(int handlerIndex) @@ -276,20 +289,74 @@ void MarkCollision(int handlerIndex) return; collisionFlags[handlerIndex] = true; - collidingIndices ??= []; - collidingIndices.Add(handlerIndex); + collidingIndices ??= ArrayPool.Shared.Rent(handlerCount); + collidingIndices[collidingCount++] = handlerIndex; } } private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestHandler) { - var className = requestHandler.Class.Name; - if (className.StartsWith(GlobalPrefix, StringComparison.Ordinal)) - className = className[GlobalPrefix.Length..]; + var className = requestHandler.Class.Name ?? string.Empty; + var methodName = requestHandler.Method.Name ?? string.Empty; + + var startIndex = className.StartsWith(GlobalPrefix, StringComparison.Ordinal) + ? GlobalPrefix.Length + : 0; + var length = className.Length - startIndex; + var containsNestedTypeSeparator = className.IndexOf('+', startIndex, length) >= 0; + + if (length == 0) + return string.Concat(".", methodName); + + var totalLength = length + 1 + methodName.Length; + var buffer = ArrayPool.Shared.Rent(totalLength); + try + { + var destinationIndex = 0; + for (var i = 0; i < length; i++) + { + var character = className[startIndex + i]; + buffer[destinationIndex++] = containsNestedTypeSeparator && character == '+' ? '.' : character; + } + + buffer[destinationIndex++] = '.'; + methodName.CopyTo(0, buffer, destinationIndex, methodName.Length); + + return new string(buffer, 0, totalLength); + } + finally + { + ArrayPool.Shared.Return(buffer); + } + } + + private readonly struct HandlerNameKey : IEquatable + { + private readonly string _name; + private readonly string _method; - if (className.IndexOf('+') >= 0) - className = className.Replace('+', '.'); + public HandlerNameKey(string name, string method) + { + _name = name; + _method = method; + } - return string.Concat(className, ".", requestHandler.Method.Name); + public bool Equals(HandlerNameKey other) + { + return ReferenceEquals(_name, other._name) && ReferenceEquals(_method, other._method) + || (string.Equals(_name, other._name, StringComparison.Ordinal) + && string.Equals(_method, other._method, StringComparison.Ordinal)); + } + + public override bool Equals(object? obj) + { + return obj is HandlerNameKey other && Equals(other); + } + + public override int GetHashCode() + { + return HashCode.Combine(StringComparer.Ordinal.GetHashCode(_name), StringComparer.Ordinal.GetHashCode(_method)); + } } + } From a79c4898eb695a9e2d07362c8bdafdac387ee35a Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 23:17:17 -0500 Subject: [PATCH 15/32] Optimize handler collision tracking (#64) --- src/GeneratedEndpoints/MinimalApiGenerator.cs | 38 ++++++++++++------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index fe72434..623fd19 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -240,9 +240,11 @@ private static ImmutableArray GetRequestHandlersWithNameCollisions(Immutabl var handlerCount = requestHandlers.Length; var nameToFirstIndex = new Dictionary(handlerCount); + var collisionFlags = ArrayPool.Shared.Rent(handlerCount); Array.Clear(collisionFlags, 0, handlerCount); - int[]? collidingIndices = null; + + int[]? collidingArray = null; var collidingCount = 0; try @@ -267,20 +269,19 @@ private static ImmutableArray GetRequestHandlersWithNameCollisions(Immutabl } } - if (collidingIndices is null || collidingCount == 0) + if (collidingCount == 0) return ImmutableArray.Empty; - Array.Sort(collidingIndices, 0, collidingCount); - var builder = ImmutableArray.CreateBuilder(collidingCount); - for (var i = 0; i < collidingCount; i++) - builder.Add(collidingIndices[i]); - return builder.MoveToImmutable(); + Array.Sort(collidingArray!, 0, collidingCount); + var sorted = new int[collidingCount]; + Array.Copy(collidingArray!, 0, sorted, 0, collidingCount); + return ImmutableArray.Create(sorted); } finally { - ArrayPool.Shared.Return(collisionFlags); - if (collidingIndices is not null) - ArrayPool.Shared.Return(collidingIndices); + ArrayPool.Shared.Return(collisionFlags, clearArray: false); + if (collidingArray is not null) + ArrayPool.Shared.Return(collidingArray, clearArray: false); } void MarkCollision(int handlerIndex) @@ -289,8 +290,8 @@ void MarkCollision(int handlerIndex) return; collisionFlags[handlerIndex] = true; - collidingIndices ??= ArrayPool.Shared.Rent(handlerCount); - collidingIndices[collidingCount++] = handlerIndex; + collidingArray ??= ArrayPool.Shared.Rent(handlerCount); + collidingArray[collidingCount++] = handlerIndex; } } @@ -334,11 +335,13 @@ private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestH { private readonly string _name; private readonly string _method; + private readonly int _hashCode; public HandlerNameKey(string name, string method) { _name = name; _method = method; + _hashCode = CombineHashCodes(StringComparer.Ordinal.GetHashCode(name), StringComparer.Ordinal.GetHashCode(method)); } public bool Equals(HandlerNameKey other) @@ -355,7 +358,16 @@ public override bool Equals(object? obj) public override int GetHashCode() { - return HashCode.Combine(StringComparer.Ordinal.GetHashCode(_name), StringComparer.Ordinal.GetHashCode(_method)); + return _hashCode; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int CombineHashCodes(int left, int right) + { + unchecked + { + return (left * 397) ^ right; + } } } From d4f96e286e40bf804033e9c669da79d97b34037b Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Sun, 16 Nov 2025 23:41:37 -0500 Subject: [PATCH 16/32] Update MinimalApiGenerator.cs --- src/GeneratedEndpoints/MinimalApiGenerator.cs | 30 ++++--------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 623fd19..9b9b2cc 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -1,4 +1,4 @@ -using System.Buffers; +using System.Buffers; using System.Collections.Immutable; using System.Runtime.CompilerServices; using GeneratedEndpoints.Common; @@ -116,30 +116,10 @@ private static (string HttpMethod, string Pattern, string? Name) GetRequestHandl { cancellationToken.ThrowIfCancellationRequested(); - var attributeName = attribute.AttributeClass?.Name ?? string.Empty; - - var httpMethod = HttpAttributeDefinitionsByName.TryGetValue(attributeName, out var definition) ? definition.Verb : string.Empty; - - var pattern = attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : null; - pattern ??= string.Empty; - - string? name = null; - var namedArguments = attribute.NamedArguments; - if (!namedArguments.IsDefaultOrEmpty) - { - for (var i = 0; i < namedArguments.Length; i++) - { - var namedArg = namedArguments[i]; - if (namedArg.Key != NameAttributeNamedParameter) - continue; - - if (namedArg.Value.Value is not string value) - break; - - name = string.IsNullOrWhiteSpace(value) ? null : value.Trim(); - break; - } - } + var attributeName = attribute.AttributeClass?.Name ?? ""; + var httpMethod = HttpAttributeDefinitionsByName.TryGetValue(attributeName, out var definition) ? definition.Verb : ""; + var pattern = attribute.GetConstructorStringValue() ?? ""; + var name = attribute.GetNamedStringValue(NameAttributeNamedParameter); return (httpMethod, pattern, name); } From 57a200d500a4d98a89e1a90d299d43d432c38a1f Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 00:07:01 -0500 Subject: [PATCH 17/32] Refactor. --- .../Common/EndpointConfiguration.cs | 1 - .../Common/EndpointConfigurationFactory.cs | 15 +----- ...terHelper.cs => MethodSymbolExtensions.cs} | 4 +- .../Common/RequestHandler.cs | 1 + .../Common/RequestHandlerClassCacheEntry.cs | 2 +- src/GeneratedEndpoints/MinimalApiGenerator.cs | 52 +++++++++---------- .../UseEndpointHandlersGenerator.cs | 27 +++++----- 7 files changed, 43 insertions(+), 59 deletions(-) rename src/GeneratedEndpoints/Common/{RequestHandlerParameterHelper.cs => MethodSymbolExtensions.cs} (96%) diff --git a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs index 2a79eaa..60c5bbc 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs @@ -2,7 +2,6 @@ namespace GeneratedEndpoints.Common; internal readonly record struct EndpointConfiguration { - public required string? Name { get; init; } public required string? DisplayName { get; init; } public required string? Summary { get; init; } public required string? Description { get; init; } diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index 9d7f2fa..f949d1e 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -1,4 +1,3 @@ -using System.Collections.Immutable; using System.Runtime.CompilerServices; using Microsoft.CodeAnalysis; using static GeneratedEndpoints.Common.Constants; @@ -9,14 +8,11 @@ internal static class EndpointConfigurationFactory { private static readonly ConditionalWeakTable GeneratedAttributeKindCache = new(); - public static EndpointConfiguration Create(ISymbol symbol, string? name) + public static EndpointConfiguration Create(ISymbol symbol) { var attributes = symbol.GetAttributes(); - if (symbol is IMethodSymbol) - name ??= RemoveAsyncSuffix(symbol.Name); - string? displayName = null; string? description = null; EquatableImmutableArray? tags = null; @@ -152,7 +148,6 @@ public static EndpointConfiguration Create(ISymbol symbol, string? name) return new EndpointConfiguration { - Name = name, DisplayName = displayName, Summary = summary, Description = description, @@ -214,14 +209,6 @@ private static RequestHandlerAttributeKind GetGeneratedAttributeKind(INamedTypeS return cacheEntry.Kind; } - private static string RemoveAsyncSuffix(string methodName) - { - if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) - return methodName[..^AsyncSuffix.Length]; - - return methodName; - } - private static EquatableImmutableArray? ToEquatableOrNull(List? values) { return values is { Count: > 0 } ? values.ToEquatableImmutableArray() : null; diff --git a/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs b/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs similarity index 96% rename from src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs rename to src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs index 7a4643d..dca5b1c 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerParameterHelper.cs +++ b/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs @@ -9,9 +9,9 @@ namespace GeneratedEndpoints.Common; // ReSharper disable LoopCanBeConvertedToQuery // Do not refactor, use for loop to avoid allocations. -internal static class RequestHandlerParameterHelper +internal static class MethodSymbolExtensions { - public static EquatableImmutableArray Build(IMethodSymbol methodSymbol, CancellationToken cancellationToken) + public static EquatableImmutableArray GetParameters(this IMethodSymbol methodSymbol, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); diff --git a/src/GeneratedEndpoints/Common/RequestHandler.cs b/src/GeneratedEndpoints/Common/RequestHandler.cs index eeb6736..69acdaa 100644 --- a/src/GeneratedEndpoints/Common/RequestHandler.cs +++ b/src/GeneratedEndpoints/Common/RequestHandler.cs @@ -5,5 +5,6 @@ internal readonly record struct RequestHandler( RequestHandlerMethod Method, string HttpMethod, string Pattern, + string? Name, EndpointConfiguration Configuration ); diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs index 6d8d368..ee0d1fe 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs @@ -28,7 +28,7 @@ public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, Compilation compilationCache.ServiceProviderSymbol, cancellationToken ); - var classConfiguration = EndpointConfigurationFactory.Create(classSymbol, null); + var classConfiguration = EndpointConfigurationFactory.Create(classSymbol); _value = new RequestHandlerClass(name, isStatic, configureMethodDetails.HasConfigureMethod, configureMethodDetails.ConfigureMethodAcceptsServiceProvider, classConfiguration diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 9b9b2cc..ae85cef 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -103,16 +103,16 @@ private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToke var requestHandlerMethod = GetRequestHandlerMethod(requestHandlerMethodSymbol, cancellationToken); - var (httpMethod, pattern, name) = GetRequestHandlerAttribute(attribute, cancellationToken); + var (httpMethod, pattern, name) = GetRequestHandlerAttribute(requestHandlerMethodSymbol, attribute, cancellationToken); - var methodConfiguration = EndpointConfigurationFactory.Create(requestHandlerMethodSymbol, name); + var methodConfiguration = EndpointConfigurationFactory.Create(requestHandlerMethodSymbol); - var requestHandler = new RequestHandler(requestHandlerClass.Value, requestHandlerMethod, httpMethod, pattern, methodConfiguration); + var requestHandler = new RequestHandler(requestHandlerClass.Value, requestHandlerMethod, httpMethod, pattern, name, methodConfiguration); return requestHandler; } - private static (string HttpMethod, string Pattern, string? Name) GetRequestHandlerAttribute(AttributeData attribute, CancellationToken cancellationToken) + private static (string HttpMethod, string Pattern, string? Name) GetRequestHandlerAttribute(IMethodSymbol methodSymbol, AttributeData attribute, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -120,17 +120,27 @@ private static (string HttpMethod, string Pattern, string? Name) GetRequestHandl var httpMethod = HttpAttributeDefinitionsByName.TryGetValue(attributeName, out var definition) ? definition.Verb : ""; var pattern = attribute.GetConstructorStringValue() ?? ""; var name = attribute.GetNamedStringValue(NameAttributeNamedParameter); + name ??= RemoveAsyncSuffix(methodSymbol.Name); + return (httpMethod, pattern, name); } + private static string RemoveAsyncSuffix(string methodName) + { + if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) + return methodName[..^AsyncSuffix.Length]; + + return methodName; + } + private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol methodSymbol, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); var name = methodSymbol.Name; var isStatic = methodSymbol.IsStatic; - var parameters = RequestHandlerParameterHelper.Build(methodSymbol, cancellationToken); + var parameters = methodSymbol.GetParameters(cancellationToken); var requestHandlerMethod = new RequestHandlerMethod(name, isStatic, parameters); @@ -181,7 +191,7 @@ private static ImmutableArray SortRequestHandlers(EquatableImmut if (count == 0) return ImmutableArray.Empty; if (count == 1) - return ImmutableArray.Create(requestHandlers[0]); + return [requestHandlers[0]]; var builder = ImmutableArray.CreateBuilder(count); builder.AddRange(requestHandlers); @@ -200,14 +210,11 @@ private static ImmutableArray EnsureUniqueEndpointNames(Immutabl { var index = collidingHandlers[i]; var handler = builder[index]; - var configuration = handler.Configuration with + var newHandler = handler with { Name = GetFullyQualifiedMethodDisplayName(handler), }; - builder[index] = handler with - { - Configuration = configuration, - }; + builder[index] = newHandler; } return builder.MoveToImmutable(); @@ -232,7 +239,7 @@ private static ImmutableArray GetRequestHandlersWithNameCollisions(Immutabl for (var index = 0; index < handlerCount; index++) { var handler = requestHandlers[index]; - var name = handler.Configuration.Name; + var name = handler.Name; if (string.IsNullOrEmpty(name)) continue; @@ -255,7 +262,7 @@ private static ImmutableArray GetRequestHandlersWithNameCollisions(Immutabl Array.Sort(collidingArray!, 0, collidingCount); var sorted = new int[collidingCount]; Array.Copy(collidingArray!, 0, sorted, 0, collidingCount); - return ImmutableArray.Create(sorted); + return [..sorted]; } finally { @@ -277,8 +284,8 @@ void MarkCollision(int handlerIndex) private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestHandler) { - var className = requestHandler.Class.Name ?? string.Empty; - var methodName = requestHandler.Method.Name ?? string.Empty; + var className = requestHandler.Class.Name; + var methodName = requestHandler.Method.Name; var startIndex = className.StartsWith(GlobalPrefix, StringComparison.Ordinal) ? GlobalPrefix.Length @@ -311,18 +318,11 @@ private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestH } } - private readonly struct HandlerNameKey : IEquatable + private readonly struct HandlerNameKey(string name, string method) : IEquatable { - private readonly string _name; - private readonly string _method; - private readonly int _hashCode; - - public HandlerNameKey(string name, string method) - { - _name = name; - _method = method; - _hashCode = CombineHashCodes(StringComparer.Ordinal.GetHashCode(name), StringComparer.Ordinal.GetHashCode(method)); - } + private readonly string _name = name; + private readonly string _method = method; + private readonly int _hashCode = CombineHashCodes(StringComparer.Ordinal.GetHashCode(name), StringComparer.Ordinal.GetHashCode(method)); public bool Equals(HandlerNameKey other) { diff --git a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs index c1c0715..3323b77 100644 --- a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs +++ b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs @@ -59,7 +59,7 @@ public static void GenerateSource(SourceProductionContext context, ImmutableArra source.Append(" = builder.MapGroup("); source.Append(groupedClass.Configuration.GroupPattern!.ToStringLiteral()); source.Append(')'); - AppendEndpointConfiguration(source, " ", groupedClass.Configuration, false); + AppendEndpointConfiguration(source, " ", groupedClass.Configuration); source.AppendLine(";"); } @@ -203,7 +203,15 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl if (requestHandler.Class.Configuration.GroupPattern is null) configuration = MergeEndpointConfigurations(requestHandler.Class.Configuration, configuration); - AppendEndpointConfiguration(source, continuationIndent, configuration, true); + if (!string.IsNullOrEmpty(requestHandler.Name)) + { + source.AppendLine(); + source.Append(continuationIndent); + source.Append(".WithName("); + source.Append(requestHandler.Name.ToStringLiteral()); + source.Append(')'); + } + AppendEndpointConfiguration(source, continuationIndent, configuration); if (wrapWithConfigure && configureAcceptsServiceProvider) { @@ -224,18 +232,9 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl } } - private static void AppendEndpointConfiguration(StringBuilder source, string indent, EndpointConfiguration configuration, bool includeNameAndDisplayName) + private static void AppendEndpointConfiguration(StringBuilder source, string indent, EndpointConfiguration configuration) { - if (includeNameAndDisplayName && !string.IsNullOrEmpty(configuration.Name)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithName("); - source.Append(configuration.Name.ToStringLiteral()); - source.Append(')'); - } - - if (includeNameAndDisplayName && !string.IsNullOrEmpty(configuration.DisplayName)) + if (!string.IsNullOrEmpty(configuration.DisplayName)) { source.AppendLine(); source.Append(indent); @@ -463,7 +462,6 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfiguration classConfiguration, EndpointConfiguration methodConfiguration) { - var name = methodConfiguration.Name ?? classConfiguration.Name; var displayName = methodConfiguration.DisplayName ?? classConfiguration.DisplayName; var summary = methodConfiguration.Summary ?? classConfiguration.Summary; var description = methodConfiguration.Description ?? classConfiguration.Description; @@ -507,7 +505,6 @@ private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfigu return new EndpointConfiguration { - Name = name, DisplayName = displayName, Summary = summary, Description = description, From 866c7e58e507ba35f949a8851388851345bd6d6b Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 00:09:30 -0500 Subject: [PATCH 18/32] Refactor. --- .../Common/HandlerNameKey.cs | 36 +++++++++++++++++++ src/GeneratedEndpoints/MinimalApiGenerator.cs | 34 ------------------ 2 files changed, 36 insertions(+), 34 deletions(-) create mode 100644 src/GeneratedEndpoints/Common/HandlerNameKey.cs diff --git a/src/GeneratedEndpoints/Common/HandlerNameKey.cs b/src/GeneratedEndpoints/Common/HandlerNameKey.cs new file mode 100644 index 0000000..7ce2e11 --- /dev/null +++ b/src/GeneratedEndpoints/Common/HandlerNameKey.cs @@ -0,0 +1,36 @@ +using System.Runtime.CompilerServices; + +namespace GeneratedEndpoints.Common; + +internal readonly struct HandlerNameKey(string name, string method) : IEquatable +{ + private readonly string _name = name; + private readonly string _method = method; + private readonly int _hashCode = CombineHashCodes(StringComparer.Ordinal.GetHashCode(name), StringComparer.Ordinal.GetHashCode(method)); + + public bool Equals(HandlerNameKey other) + { + return ReferenceEquals(_name, other._name) && ReferenceEquals(_method, other._method) + || (string.Equals(_name, other._name, StringComparison.Ordinal) + && string.Equals(_method, other._method, StringComparison.Ordinal)); + } + + public override bool Equals(object? obj) + { + return obj is HandlerNameKey other && Equals(other); + } + + public override int GetHashCode() + { + return _hashCode; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int CombineHashCodes(int left, int right) + { + unchecked + { + return (left * 397) ^ right; + } + } +} diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index ae85cef..b7edcff 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -317,38 +317,4 @@ private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestH ArrayPool.Shared.Return(buffer); } } - - private readonly struct HandlerNameKey(string name, string method) : IEquatable - { - private readonly string _name = name; - private readonly string _method = method; - private readonly int _hashCode = CombineHashCodes(StringComparer.Ordinal.GetHashCode(name), StringComparer.Ordinal.GetHashCode(method)); - - public bool Equals(HandlerNameKey other) - { - return ReferenceEquals(_name, other._name) && ReferenceEquals(_method, other._method) - || (string.Equals(_name, other._name, StringComparison.Ordinal) - && string.Equals(_method, other._method, StringComparison.Ordinal)); - } - - public override bool Equals(object? obj) - { - return obj is HandlerNameKey other && Equals(other); - } - - public override int GetHashCode() - { - return _hashCode; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int CombineHashCodes(int left, int right) - { - unchecked - { - return (left * 397) ^ right; - } - } - } - } From 4eb59675316fd78de7c65dbc229e98091b7ffb4d Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 00:24:06 -0500 Subject: [PATCH 19/32] Add endpoint name collision coverage (#65) --- .../Common/SourceFactory.cs | 27 ++++++++++ ...ollisions_AddEndpointHandlers.verified.txt | 23 +++++++++ ...ollisions_MapEndpointHandlers.verified.txt | 51 +++++++++++++++++++ .../IndividualTests.cs | 10 ++++ 4 files changed, 111 insertions(+) create mode 100644 tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_AddEndpointHandlers.verified.txt create mode 100644 tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_MapEndpointHandlers.verified.txt diff --git a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs index 2e25ff4..2b843d6 100644 --- a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs +++ b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs @@ -343,6 +343,33 @@ public static string BuildHttpMethodMatrixSource( return builder.ToString(); } + public static string BuildEndpointNameCollisionSource() + => """ + internal static class AlphaEndpoints + { + [MapGet("/alpha/collision")] public static Ok Collision() => TypedResults.Ok("alpha-collision"); + [MapGet("/alpha/unique")] public static Ok UniqueAlpha() => TypedResults.Ok("unique-alpha"); + } + + internal static class BetaEndpoints + { + [MapGet("/beta/unique")] public static Ok UniqueBeta() => TypedResults.Ok("unique-beta"); + [MapGet("/beta/collision")] public static Ok Collision() => TypedResults.Ok("beta-collision"); + } + + internal static class GammaEndpoints + { + [MapGet("/gamma/collision")] public static Ok Collision() => TypedResults.Ok("gamma-collision"); + [MapGet("/gamma/unique")] public static Ok UniqueGamma() => TypedResults.Ok("unique-gamma"); + } + + internal static class DeltaEndpoints + { + [MapGet("/delta/unique")] public static Ok UniqueDelta() => TypedResults.Ok("unique-delta"); + [MapGet("/delta/collision")] public static Ok Collision() => TypedResults.Ok("delta-collision"); + } + """; + public static string BuildContractsAndBindingSource( bool includeBindingNames, bool includeAsParameters, diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_AddEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_AddEndpointHandlers.verified.txt new file mode 100644 index 0000000..64b4329 --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_AddEndpointHandlers.verified.txt @@ -0,0 +1,23 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointServicesExtensions +{ + internal static void AddEndpointHandlers(this IServiceCollection services) + { + } +} diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_MapEndpointHandlers.verified.txt new file mode 100644 index 0000000..8950dd4 --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_MapEndpointHandlers.verified.txt @@ -0,0 +1,51 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointRouteBuilderExtensions +{ + internal static IEndpointRouteBuilder MapEndpointHandlers(this IEndpointRouteBuilder builder) + { + builder.MapGet("/alpha/collision", global::GeneratedEndpointsTests.AlphaEndpoints.Collision) + .WithName("GeneratedEndpointsTests.AlphaEndpoints.Collision"); + + builder.MapGet("/alpha/unique", global::GeneratedEndpointsTests.AlphaEndpoints.UniqueAlpha) + .WithName("UniqueAlpha"); + + builder.MapGet("/beta/collision", global::GeneratedEndpointsTests.BetaEndpoints.Collision) + .WithName("GeneratedEndpointsTests.BetaEndpoints.Collision"); + + builder.MapGet("/beta/unique", global::GeneratedEndpointsTests.BetaEndpoints.UniqueBeta) + .WithName("UniqueBeta"); + + builder.MapGet("/delta/collision", global::GeneratedEndpointsTests.DeltaEndpoints.Collision) + .WithName("GeneratedEndpointsTests.DeltaEndpoints.Collision"); + + builder.MapGet("/delta/unique", global::GeneratedEndpointsTests.DeltaEndpoints.UniqueDelta) + .WithName("UniqueDelta"); + + builder.MapGet("/gamma/collision", global::GeneratedEndpointsTests.GammaEndpoints.Collision) + .WithName("GeneratedEndpointsTests.GammaEndpoints.Collision"); + + builder.MapGet("/gamma/unique", global::GeneratedEndpointsTests.GammaEndpoints.UniqueGamma) + .WithName("UniqueGamma"); + + return builder; + } +} diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.cs b/tests/GeneratedEndpoints.Tests/IndividualTests.cs index fbeff64..ea30eb9 100644 --- a/tests/GeneratedEndpoints.Tests/IndividualTests.cs +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.cs @@ -327,6 +327,13 @@ public async Task MethodNameCollision() await VerifyIndividualAsync(source, nameof(MethodNameCollision)); } + [Fact] + public async Task MultipleEndpointNameCollisions() + { + var source = EndpointNameCollisionScenario(); + await VerifyIndividualAsync(source, nameof(MultipleEndpointNameCollisions)); + } + [Fact] public async Task BindingNames() { @@ -561,6 +568,9 @@ private static string HttpMethodScenario( includeConnect, includeMethodNameCollision); + private static string EndpointNameCollisionScenario() + => SourceFactory.BuildEndpointNameCollisionSource(); + private static string ContractScenario( bool includeBindingNames = false, bool includeAsParameters = false, From 8f7ca52b51720181b1784f7485fda7da7f45dae8 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 00:24:26 -0500 Subject: [PATCH 20/32] Removed compilation cache. --- .../Common/CompilationTypeCache.cs | 11 ------- .../Common/RequestHandlerClassCacheEntry.cs | 30 +++++-------------- src/GeneratedEndpoints/MinimalApiGenerator.cs | 13 ++------ 3 files changed, 11 insertions(+), 43 deletions(-) delete mode 100644 src/GeneratedEndpoints/Common/CompilationTypeCache.cs diff --git a/src/GeneratedEndpoints/Common/CompilationTypeCache.cs b/src/GeneratedEndpoints/Common/CompilationTypeCache.cs deleted file mode 100644 index 0567f3a..0000000 --- a/src/GeneratedEndpoints/Common/CompilationTypeCache.cs +++ /dev/null @@ -1,11 +0,0 @@ -using Microsoft.CodeAnalysis; - -namespace GeneratedEndpoints.Common; - -internal sealed class CompilationTypeCache(Compilation compilation) -{ - public INamedTypeSymbol? EndpointConventionBuilderSymbol { get; } = - compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Builder.IEndpointConventionBuilder"); - - public INamedTypeSymbol? ServiceProviderSymbol { get; } = compilation.GetTypeByMetadataName("System.IServiceProvider"); -} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs index ee0d1fe..645073b 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs @@ -10,7 +10,7 @@ internal sealed class RequestHandlerClassCacheEntry private RequestHandlerClass _value; private bool _initialized; - public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, CompilationTypeCache compilationCache, CancellationToken cancellationToken) + public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, CancellationToken cancellationToken) { if (_initialized) return _value; @@ -24,8 +24,7 @@ public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, Compilation var name = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var isStatic = classSymbol.IsStatic; - var configureMethodDetails = GetConfigureMethodDetails(classSymbol, compilationCache.EndpointConventionBuilderSymbol, - compilationCache.ServiceProviderSymbol, cancellationToken + var configureMethodDetails = GetConfigureMethodDetails(classSymbol, cancellationToken ); var classConfiguration = EndpointConfigurationFactory.Create(classSymbol); @@ -40,8 +39,6 @@ public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, Compilation private static ConfigureMethodDetails GetConfigureMethodDetails( INamedTypeSymbol classSymbol, - INamedTypeSymbol? endpointConventionBuilderSymbol, - INamedTypeSymbol? serviceProviderSymbol, CancellationToken cancellationToken ) { @@ -56,7 +53,7 @@ CancellationToken cancellationToken if (member is not IMethodSymbol methodSymbol) continue; - if (IsConfigureMethod(methodSymbol, endpointConventionBuilderSymbol, serviceProviderSymbol, out var methodAcceptsServiceProvider)) + if (IsConfigureMethod(methodSymbol, out var methodAcceptsServiceProvider)) { hasConfigureMethod = true; if (methodAcceptsServiceProvider) @@ -72,8 +69,6 @@ CancellationToken cancellationToken private static bool IsConfigureMethod( IMethodSymbol methodSymbol, - INamedTypeSymbol? endpointConventionBuilderSymbol, - INamedTypeSymbol? serviceProviderSymbol, out bool acceptsServiceProvider ) { @@ -97,7 +92,7 @@ out bool acceptsServiceProvider if (methodSymbol.Parameters.Length == 2) { var serviceProviderParameter = methodSymbol.Parameters[1]; - if (!IsServiceProviderParameter(serviceProviderParameter.Type, serviceProviderSymbol)) + if (!IsServiceProviderParameter(serviceProviderParameter.Type)) return false; acceptsServiceProvider = true; @@ -106,32 +101,23 @@ out bool acceptsServiceProvider if (!methodSymbol.ReturnsVoid) return false; - if (!HasEndpointConventionBuilderConstraint(builderTypeParameter, methodSymbol, endpointConventionBuilderSymbol)) + if (!HasEndpointConventionBuilderConstraint(builderTypeParameter, methodSymbol)) return false; return true; } - private static bool IsServiceProviderParameter(ITypeSymbol typeSymbol, INamedTypeSymbol? serviceProviderSymbol) + private static bool IsServiceProviderParameter(ITypeSymbol typeSymbol) { - if (serviceProviderSymbol is not null) - return SymbolEqualityComparer.Default.Equals(typeSymbol, serviceProviderSymbol); - return MatchesServiceProvider(typeSymbol); } private static bool HasEndpointConventionBuilderConstraint( ITypeParameterSymbol builderTypeParameter, - IMethodSymbol methodSymbol, - INamedTypeSymbol? endpointConventionBuilderSymbol + IMethodSymbol methodSymbol ) { - var symbolMatches = builderTypeParameter.ConstraintTypes.Any(constraint => - endpointConventionBuilderSymbol is not null - ? SymbolEqualityComparer.Default.Equals(constraint, endpointConventionBuilderSymbol) - : MatchesEndpointConventionBuilder(constraint) - ); - + var symbolMatches = builderTypeParameter.ConstraintTypes.Any(MatchesEndpointConventionBuilder); if (symbolMatches) return true; diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index b7edcff..1fc6652 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -15,7 +15,6 @@ namespace GeneratedEndpoints; [Generator] public sealed class MinimalApiGenerator : IIncrementalGenerator { - private static readonly ConditionalWeakTable CompilationTypeCaches = new(); private static readonly ConditionalWeakTable RequestHandlerClassCache = new(); public void Initialize(IncrementalGeneratorInitializationContext context) @@ -97,7 +96,7 @@ private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToke return null; var attribute = context.Attributes[0]; - var requestHandlerClass = GetRequestHandlerClass(requestHandlerMethodSymbol, context.SemanticModel.Compilation, cancellationToken); + var requestHandlerClass = GetRequestHandlerClass(requestHandlerMethodSymbol, cancellationToken); if (requestHandlerClass is null) return null; @@ -147,7 +146,7 @@ private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol method return requestHandlerMethod; } - private static RequestHandlerClass? GetRequestHandlerClass(IMethodSymbol methodSymbol, Compilation compilation, CancellationToken cancellationToken) + private static RequestHandlerClass? GetRequestHandlerClass(IMethodSymbol methodSymbol, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -155,18 +154,12 @@ private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol method if (classSymbol.TypeKind != TypeKind.Class) return null; - var typeCache = GetCompilationTypeCache(compilation); var cacheEntry = RequestHandlerClassCache.GetValue(classSymbol, static _ => new RequestHandlerClassCacheEntry()); - var requestHandlerClass = cacheEntry.GetOrCreate(classSymbol, typeCache, cancellationToken); + var requestHandlerClass = cacheEntry.GetOrCreate(classSymbol, cancellationToken); return requestHandlerClass; } - private static CompilationTypeCache GetCompilationTypeCache(Compilation compilation) - { - return CompilationTypeCaches.GetValue(compilation, static c => new CompilationTypeCache(c)); - } - private static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) { context.CancellationToken.ThrowIfCancellationRequested(); From 48ec71b0a2ac3e5363ec074dced715364f152657 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 00:28:01 -0500 Subject: [PATCH 21/32] Optimize endpoint name collision handling (#66) --- src/GeneratedEndpoints/MinimalApiGenerator.cs | 103 +++++++----------- 1 file changed, 40 insertions(+), 63 deletions(-) diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 1fc6652..3a84e11 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -193,86 +193,63 @@ private static ImmutableArray SortRequestHandlers(EquatableImmut } private static ImmutableArray EnsureUniqueEndpointNames(ImmutableArray requestHandlers) - { - var collidingHandlers = GetRequestHandlersWithNameCollisions(requestHandlers); - if (collidingHandlers.IsEmpty) - return requestHandlers; - - var builder = requestHandlers.ToBuilder(); - for (var i = 0; i < collidingHandlers.Length; i++) - { - var index = collidingHandlers[i]; - var handler = builder[index]; - var newHandler = handler with - { - Name = GetFullyQualifiedMethodDisplayName(handler), - }; - builder[index] = newHandler; - } - - return builder.MoveToImmutable(); - } - - private static ImmutableArray GetRequestHandlersWithNameCollisions(ImmutableArray requestHandlers) { if (requestHandlers.IsDefaultOrEmpty) - return ImmutableArray.Empty; + return requestHandlers; var handlerCount = requestHandlers.Length; - var nameToFirstIndex = new Dictionary(handlerCount); + var nameToCollision = new Dictionary(handlerCount); + ImmutableArray.Builder? builder = null; - var collisionFlags = ArrayPool.Shared.Rent(handlerCount); - Array.Clear(collisionFlags, 0, handlerCount); + for (var index = 0; index < handlerCount; index++) + { + var handler = requestHandlers[index]; + var name = handler.Name; + if (string.IsNullOrEmpty(name)) + continue; - int[]? collidingArray = null; - var collidingCount = 0; + var key = new HandlerNameKey(name!, handler.Method.Name); - try - { - for (var index = 0; index < handlerCount; index++) + if (!nameToCollision.TryGetValue(key, out var collision)) { - var handler = requestHandlers[index]; - var name = handler.Name; - if (string.IsNullOrEmpty(name)) - continue; - - var key = new HandlerNameKey(name!, handler.Method.Name); + nameToCollision.Add(key, new CollisionInfo(index, firstHandlerRenamed: false)); + continue; + } - if (nameToFirstIndex.TryGetValue(key, out var firstIndex)) - { - MarkCollision(firstIndex); - MarkCollision(index); - } - else + builder ??= requestHandlers.ToBuilder(); + if (!collision.FirstHandlerRenamed) + { + var firstHandler = builder[collision.FirstIndex]; + builder[collision.FirstIndex] = firstHandler with { - nameToFirstIndex.Add(key, index); - } + Name = GetFullyQualifiedMethodDisplayName(firstHandler), + }; + collision = collision.WithFirstHandlerRenamed(); } - if (collidingCount == 0) - return ImmutableArray.Empty; - - Array.Sort(collidingArray!, 0, collidingCount); - var sorted = new int[collidingCount]; - Array.Copy(collidingArray!, 0, sorted, 0, collidingCount); - return [..sorted]; + builder[index] = handler with + { + Name = GetFullyQualifiedMethodDisplayName(handler), + }; + nameToCollision[key] = collision; } - finally + + return builder is null ? requestHandlers : builder.MoveToImmutable(); + } + + private readonly struct CollisionInfo + { + public CollisionInfo(int firstIndex, bool firstHandlerRenamed) { - ArrayPool.Shared.Return(collisionFlags, clearArray: false); - if (collidingArray is not null) - ArrayPool.Shared.Return(collidingArray, clearArray: false); + FirstIndex = firstIndex; + FirstHandlerRenamed = firstHandlerRenamed; } - void MarkCollision(int handlerIndex) - { - if (collisionFlags[handlerIndex]) - return; + public int FirstIndex { get; } - collisionFlags[handlerIndex] = true; - collidingArray ??= ArrayPool.Shared.Rent(handlerCount); - collidingArray[collidingCount++] = handlerIndex; - } + public bool FirstHandlerRenamed { get; } + + public CollisionInfo WithFirstHandlerRenamed() => new(FirstIndex, true); } private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestHandler) From 911bdec521d4323c64660aab5861cc17bfc5e9a9 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 00:34:52 -0500 Subject: [PATCH 22/32] Refactor. --- .../Common/CollisionInfo.cs | 13 ++++ .../Common/RequestHandler.cs | 3 +- .../Common/RequestHandlerMethod.cs | 3 +- src/GeneratedEndpoints/MinimalApiGenerator.cs | 68 +++++++------------ 4 files changed, 42 insertions(+), 45 deletions(-) create mode 100644 src/GeneratedEndpoints/Common/CollisionInfo.cs diff --git a/src/GeneratedEndpoints/Common/CollisionInfo.cs b/src/GeneratedEndpoints/Common/CollisionInfo.cs new file mode 100644 index 0000000..192a0d5 --- /dev/null +++ b/src/GeneratedEndpoints/Common/CollisionInfo.cs @@ -0,0 +1,13 @@ +namespace GeneratedEndpoints.Common; + +internal readonly struct CollisionInfo(int firstIndex, bool firstHandlerRenamed) +{ + public int FirstIndex { get; } = firstIndex; + + public bool FirstHandlerRenamed { get; } = firstHandlerRenamed; + + public CollisionInfo WithFirstHandlerRenamed() + { + return new CollisionInfo(FirstIndex, true); + } +} diff --git a/src/GeneratedEndpoints/Common/RequestHandler.cs b/src/GeneratedEndpoints/Common/RequestHandler.cs index 69acdaa..ca9954d 100644 --- a/src/GeneratedEndpoints/Common/RequestHandler.cs +++ b/src/GeneratedEndpoints/Common/RequestHandler.cs @@ -5,6 +5,5 @@ internal readonly record struct RequestHandler( RequestHandlerMethod Method, string HttpMethod, string Pattern, - string? Name, - EndpointConfiguration Configuration + string? Name ); diff --git a/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs b/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs index 55aeba8..54e2958 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs @@ -3,5 +3,6 @@ namespace GeneratedEndpoints.Common; internal readonly record struct RequestHandlerMethod( string Name, bool IsStatic, - EquatableImmutableArray Parameters + EquatableImmutableArray Parameters, + EndpointConfiguration Configuration ); diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 3a84e11..f32e1ef 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -92,26 +92,28 @@ private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToke { cancellationToken.ThrowIfCancellationRequested(); - if (context.TargetSymbol is not IMethodSymbol requestHandlerMethodSymbol) + if (context.TargetSymbol is not IMethodSymbol methodSymbol) return null; var attribute = context.Attributes[0]; - var requestHandlerClass = GetRequestHandlerClass(requestHandlerMethodSymbol, cancellationToken); + var requestHandlerClass = GetRequestHandlerClass(methodSymbol, cancellationToken); if (requestHandlerClass is null) return null; - var requestHandlerMethod = GetRequestHandlerMethod(requestHandlerMethodSymbol, cancellationToken); + var requestHandlerMethod = GetRequestHandlerMethod(methodSymbol, cancellationToken); - var (httpMethod, pattern, name) = GetRequestHandlerAttribute(requestHandlerMethodSymbol, attribute, cancellationToken); + var (httpMethod, pattern, name) = GetRequestHandlerAttribute(methodSymbol, attribute, cancellationToken); - var methodConfiguration = EndpointConfigurationFactory.Create(requestHandlerMethodSymbol); - - var requestHandler = new RequestHandler(requestHandlerClass.Value, requestHandlerMethod, httpMethod, pattern, name, methodConfiguration); + var requestHandler = new RequestHandler(requestHandlerClass.Value, requestHandlerMethod, httpMethod, pattern, name); return requestHandler; } - private static (string HttpMethod, string Pattern, string? Name) GetRequestHandlerAttribute(IMethodSymbol methodSymbol, AttributeData attribute, CancellationToken cancellationToken) + private static (string HttpMethod, string Pattern, string? Name) GetRequestHandlerAttribute( + IMethodSymbol methodSymbol, + AttributeData attribute, + CancellationToken cancellationToken + ) { cancellationToken.ThrowIfCancellationRequested(); @@ -121,7 +123,6 @@ private static (string HttpMethod, string Pattern, string? Name) GetRequestHandl var name = attribute.GetNamedStringValue(NameAttributeNamedParameter); name ??= RemoveAsyncSuffix(methodSymbol.Name); - return (httpMethod, pattern, name); } @@ -133,19 +134,6 @@ private static string RemoveAsyncSuffix(string methodName) return methodName; } - private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol methodSymbol, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - var name = methodSymbol.Name; - var isStatic = methodSymbol.IsStatic; - var parameters = methodSymbol.GetParameters(cancellationToken); - - var requestHandlerMethod = new RequestHandlerMethod(name, isStatic, parameters); - - return requestHandlerMethod; - } - private static RequestHandlerClass? GetRequestHandlerClass(IMethodSymbol methodSymbol, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -160,6 +148,19 @@ private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol method return requestHandlerClass; } + private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol methodSymbol, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + var name = methodSymbol.Name; + var isStatic = methodSymbol.IsStatic; + var parameters = methodSymbol.GetParameters(cancellationToken); + var configuration = EndpointConfigurationFactory.Create(methodSymbol); + var requestHandlerMethod = new RequestHandlerMethod(name, isStatic, parameters, configuration); + + return requestHandlerMethod; + } + private static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) { context.CancellationToken.ThrowIfCancellationRequested(); @@ -212,7 +213,7 @@ private static ImmutableArray EnsureUniqueEndpointNames(Immutabl if (!nameToCollision.TryGetValue(key, out var collision)) { - nameToCollision.Add(key, new CollisionInfo(index, firstHandlerRenamed: false)); + nameToCollision.Add(key, new CollisionInfo(index, false)); continue; } @@ -234,22 +235,7 @@ private static ImmutableArray EnsureUniqueEndpointNames(Immutabl nameToCollision[key] = collision; } - return builder is null ? requestHandlers : builder.MoveToImmutable(); - } - - private readonly struct CollisionInfo - { - public CollisionInfo(int firstIndex, bool firstHandlerRenamed) - { - FirstIndex = firstIndex; - FirstHandlerRenamed = firstHandlerRenamed; - } - - public int FirstIndex { get; } - - public bool FirstHandlerRenamed { get; } - - public CollisionInfo WithFirstHandlerRenamed() => new(FirstIndex, true); + return builder?.MoveToImmutable() ?? requestHandlers; } private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestHandler) @@ -257,9 +243,7 @@ private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestH var className = requestHandler.Class.Name; var methodName = requestHandler.Method.Name; - var startIndex = className.StartsWith(GlobalPrefix, StringComparison.Ordinal) - ? GlobalPrefix.Length - : 0; + var startIndex = className.StartsWith(GlobalPrefix, StringComparison.Ordinal) ? GlobalPrefix.Length : 0; var length = className.Length - startIndex; var containsNestedTypeSeparator = className.IndexOf('+', startIndex, length) >= 0; From 4d4825202ccfd534d781c46e1e11763df997c838 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 00:36:29 -0500 Subject: [PATCH 23/32] Fix. --- src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs index 3323b77..6f005bf 100644 --- a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs +++ b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs @@ -92,7 +92,7 @@ private static bool HasRateLimitedHandlers(ImmutableArray reques for (var index = 0; index < requestHandlers.Length; index++) { var handler = requestHandlers[index]; - if (handler.Configuration.RequireRateLimiting) + if (handler.Method.Configuration.RequireRateLimiting) return true; } @@ -199,7 +199,7 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl } source.Append(')'); - var configuration = requestHandler.Configuration; + var configuration = requestHandler.Method.Configuration; if (requestHandler.Class.Configuration.GroupPattern is null) configuration = MergeEndpointConfigurations(requestHandler.Class.Configuration, configuration); From 5f61074e3e786a2f4c2353d85c5e86eda84ad984 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 00:43:39 -0500 Subject: [PATCH 24/32] Simplify request handler normalization (#67) --- src/GeneratedEndpoints/MinimalApiGenerator.cs | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index f32e1ef..3e3901c 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -173,22 +173,26 @@ private static void GenerateSource(SourceProductionContext context, EquatableImm private static ImmutableArray NormalizeRequestHandlers(EquatableImmutableArray requestHandlers) { - var sorted = SortRequestHandlers(requestHandlers); - sorted = EnsureUniqueEndpointNames(sorted); + if (requestHandlers.Count == 0) + return ImmutableArray.Empty; + + if (requestHandlers.Count == 1) + return [requestHandlers[0]]; - return sorted; + var sorted = SortRequestHandlers(requestHandlers); + return EnsureUniqueEndpointNames(sorted); } private static ImmutableArray SortRequestHandlers(EquatableImmutableArray requestHandlers) { var count = requestHandlers.Count; - if (count == 0) - return ImmutableArray.Empty; - if (count == 1) - return [requestHandlers[0]]; var builder = ImmutableArray.CreateBuilder(count); - builder.AddRange(requestHandlers); + for (var index = 0; index < count; index++) + { + builder.Add(requestHandlers[index]); + } + builder.Sort(RequestHandlerComparer.Instance); return builder.MoveToImmutable(); } @@ -199,18 +203,18 @@ private static ImmutableArray EnsureUniqueEndpointNames(Immutabl return requestHandlers; var handlerCount = requestHandlers.Length; - var nameToCollision = new Dictionary(handlerCount); + Dictionary? nameToCollision = null; ImmutableArray.Builder? builder = null; for (var index = 0; index < handlerCount; index++) { - var handler = requestHandlers[index]; + var handler = builder is null ? requestHandlers[index] : builder[index]; var name = handler.Name; if (string.IsNullOrEmpty(name)) continue; + nameToCollision ??= new Dictionary(handlerCount); var key = new HandlerNameKey(name!, handler.Method.Name); - if (!nameToCollision.TryGetValue(key, out var collision)) { nameToCollision.Add(key, new CollisionInfo(index, false)); @@ -218,6 +222,7 @@ private static ImmutableArray EnsureUniqueEndpointNames(Immutabl } builder ??= requestHandlers.ToBuilder(); + handler = builder[index]; if (!collision.FirstHandlerRenamed) { var firstHandler = builder[collision.FirstIndex]; From d66c94163784319ede91eb612c2f34853b0f13f3 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 00:55:55 -0500 Subject: [PATCH 25/32] Use EquatableImmutableArray for request handlers (#68) --- .../AddEndpointHandlersGenerator.cs | 11 ++++----- .../Common/EquatableImmutableArray`1.cs | 5 ++++ src/GeneratedEndpoints/MinimalApiGenerator.cs | 23 ++++++++----------- .../UseEndpointHandlersGenerator.cs | 18 +++++++-------- 4 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs index b69c6fc..2da92c7 100644 --- a/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs +++ b/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs @@ -1,5 +1,4 @@ -using System.Collections.Immutable; -using System.Text; +using System.Text; using GeneratedEndpoints.Common; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.Text; @@ -13,7 +12,7 @@ namespace GeneratedEndpoints; internal static class AddEndpointHandlersGenerator { - public static void GenerateSource(SourceProductionContext context, ImmutableArray requestHandlers) + public static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) { context.CancellationToken.ThrowIfCancellationRequested(); @@ -63,14 +62,14 @@ public static void GenerateSource(SourceProductionContext context, ImmutableArra context.AddSource(AddEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); } - private static List GetDistinctNonStaticClassNames(ImmutableArray requestHandlers) + private static List GetDistinctNonStaticClassNames(EquatableImmutableArray requestHandlers) { var classNames = new List(); - if (requestHandlers.IsDefaultOrEmpty) + if (requestHandlers.Count == 0) return classNames; var seen = new HashSet(StringComparer.Ordinal); - for (var index = 0; index < requestHandlers.Length; index++) + for (var index = 0; index < requestHandlers.Count; index++) { var requestHandler = requestHandlers[index]; if (requestHandler.Class.IsStatic) diff --git a/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs b/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs index be499bf..59434e3 100644 --- a/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs +++ b/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs @@ -21,6 +21,11 @@ namespace GeneratedEndpoints.Common; private ImmutableArray Array => _array ?? ImmutableArray.Empty; private readonly ImmutableArray? _array; + internal ImmutableArray AsImmutableArray() + { + return Array; + } + internal EquatableImmutableArray(ImmutableArray? array) { _array = array; diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 3e3901c..6c624cf 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -171,19 +171,16 @@ private static void GenerateSource(SourceProductionContext context, EquatableImm UseEndpointHandlersGenerator.GenerateSource(context, normalized); } - private static ImmutableArray NormalizeRequestHandlers(EquatableImmutableArray requestHandlers) + private static EquatableImmutableArray NormalizeRequestHandlers(EquatableImmutableArray requestHandlers) { - if (requestHandlers.Count == 0) - return ImmutableArray.Empty; - - if (requestHandlers.Count == 1) - return [requestHandlers[0]]; + if (requestHandlers.Count <= 1) + return requestHandlers; var sorted = SortRequestHandlers(requestHandlers); return EnsureUniqueEndpointNames(sorted); } - private static ImmutableArray SortRequestHandlers(EquatableImmutableArray requestHandlers) + private static EquatableImmutableArray SortRequestHandlers(EquatableImmutableArray requestHandlers) { var count = requestHandlers.Count; @@ -194,15 +191,15 @@ private static ImmutableArray SortRequestHandlers(EquatableImmut } builder.Sort(RequestHandlerComparer.Instance); - return builder.MoveToImmutable(); + return builder.ToEquatableImmutable(); } - private static ImmutableArray EnsureUniqueEndpointNames(ImmutableArray requestHandlers) + private static EquatableImmutableArray EnsureUniqueEndpointNames(EquatableImmutableArray requestHandlers) { - if (requestHandlers.IsDefaultOrEmpty) + if (requestHandlers.Count == 0) return requestHandlers; - var handlerCount = requestHandlers.Length; + var handlerCount = requestHandlers.Count; Dictionary? nameToCollision = null; ImmutableArray.Builder? builder = null; @@ -221,7 +218,7 @@ private static ImmutableArray EnsureUniqueEndpointNames(Immutabl continue; } - builder ??= requestHandlers.ToBuilder(); + builder ??= requestHandlers.AsImmutableArray().ToBuilder(); handler = builder[index]; if (!collision.FirstHandlerRenamed) { @@ -240,7 +237,7 @@ private static ImmutableArray EnsureUniqueEndpointNames(Immutabl nameToCollision[key] = collision; } - return builder?.MoveToImmutable() ?? requestHandlers; + return builder is null ? requestHandlers : builder.ToEquatableImmutable(); } private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestHandler) diff --git a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs index 6f005bf..9eeee1b 100644 --- a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs +++ b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs @@ -13,7 +13,7 @@ namespace GeneratedEndpoints; internal static class UseEndpointHandlersGenerator { - public static void GenerateSource(SourceProductionContext context, ImmutableArray requestHandlers) + public static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) { context.CancellationToken.ThrowIfCancellationRequested(); @@ -66,7 +66,7 @@ public static void GenerateSource(SourceProductionContext context, ImmutableArra if (groupedClasses.Count > 0) source.AppendLine(); - for (var index = 0; index < requestHandlers.Length; index++) + for (var index = 0; index < requestHandlers.Count; index++) { if (index > 0) source.AppendLine(); @@ -87,9 +87,9 @@ public static void GenerateSource(SourceProductionContext context, ImmutableArra context.AddSource(UseEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); } - private static bool HasRateLimitedHandlers(ImmutableArray requestHandlers) + private static bool HasRateLimitedHandlers(EquatableImmutableArray requestHandlers) { - for (var index = 0; index < requestHandlers.Length; index++) + for (var index = 0; index < requestHandlers.Count; index++) { var handler = requestHandlers[index]; if (handler.Method.Configuration.RequireRateLimiting) @@ -99,14 +99,14 @@ private static bool HasRateLimitedHandlers(ImmutableArray reques return false; } - private static List GetClassesWithMapGroups(ImmutableArray requestHandlers) + private static List GetClassesWithMapGroups(EquatableImmutableArray requestHandlers) { var groupedClasses = new List(); - if (requestHandlers.IsDefaultOrEmpty) + if (requestHandlers.Count == 0) return groupedClasses; var seen = new HashSet(StringComparer.Ordinal); - for (var index = 0; index < requestHandlers.Length; index++) + for (var index = 0; index < requestHandlers.Count; index++) { var handler = requestHandlers[index]; var handlerClass = handler.Class; @@ -603,12 +603,12 @@ private static EquatableImmutableArray MergeUnion(EquatableImmutableArra }; } - private static StringBuilder GetUseEndpointHandlersStringBuilder(ImmutableArray requestHandlers) + private static StringBuilder GetUseEndpointHandlersStringBuilder(EquatableImmutableArray requestHandlers) { const int baseSize = 4096; const int perHandler = 512; - var handlerCount = Math.Max(requestHandlers.Length, 0); + var handlerCount = Math.Max(requestHandlers.Count, 0); var estimate = baseSize + (long)perHandler * handlerCount; estimate = (long)(estimate * 1.10); From 7821816f0581ee15268a2cf004752a40825c7f9e Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 01:04:44 -0500 Subject: [PATCH 26/32] Refactored. --- .../Common/EquatableImmutableArray`1.cs | 24 +++++++++++++++++++ src/GeneratedEndpoints/MinimalApiGenerator.cs | 20 ++++------------ 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs b/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs index 59434e3..04407e9 100644 --- a/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs +++ b/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs @@ -1,5 +1,6 @@ using System.Collections; using System.Collections.Immutable; +using System.Runtime.InteropServices; namespace GeneratedEndpoints.Common; @@ -31,6 +32,29 @@ internal EquatableImmutableArray(ImmutableArray? array) _array = array; } + /// + /// Sorts the underlying array in place using the specified comparer. + /// WARNING: This mutates the underlying storage of the ImmutableArray and + /// must only be used when you can guarantee no other code observes it as immutable. + /// + internal EquatableImmutableArray SortInPlace(IComparer? comparer = null) + { + if (_array is null) + return this; + + var array = _array.Value; + if (array.Length <= 1) + return this; + + comparer ??= Comparer.Default; + + var raw = ImmutableCollectionsMarshal.AsArray(array); + if (raw is not null) + System.Array.Sort(raw, comparer); + + return this; + } + /// public bool Equals(EquatableImmutableArray other) { diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 6c624cf..0a150e9 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -176,22 +176,10 @@ private static EquatableImmutableArray NormalizeRequestHandlers( if (requestHandlers.Count <= 1) return requestHandlers; - var sorted = SortRequestHandlers(requestHandlers); - return EnsureUniqueEndpointNames(sorted); - } - - private static EquatableImmutableArray SortRequestHandlers(EquatableImmutableArray requestHandlers) - { - var count = requestHandlers.Count; - - var builder = ImmutableArray.CreateBuilder(count); - for (var index = 0; index < count; index++) - { - builder.Add(requestHandlers[index]); - } + var sorted = requestHandlers.SortInPlace(RequestHandlerComparer.Instance); + var unique = EnsureUniqueEndpointNames(sorted); - builder.Sort(RequestHandlerComparer.Instance); - return builder.ToEquatableImmutable(); + return unique; } private static EquatableImmutableArray EnsureUniqueEndpointNames(EquatableImmutableArray requestHandlers) @@ -237,7 +225,7 @@ private static EquatableImmutableArray EnsureUniqueEndpointNames nameToCollision[key] = collision; } - return builder is null ? requestHandlers : builder.ToEquatableImmutable(); + return builder?.ToEquatableImmutable() ?? requestHandlers; } private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestHandler) From 2fe6fb6cb41a23a42a4f13e18a15f862f551e130 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 01:37:55 -0500 Subject: [PATCH 27/32] Refactored. --- src/GeneratedEndpoints/Common/Constants.cs | 3 +- .../Common/RequestHandler.cs | 39 +++++++++++--- src/GeneratedEndpoints/MinimalApiGenerator.cs | 51 +++++-------------- 3 files changed, 46 insertions(+), 47 deletions(-) diff --git a/src/GeneratedEndpoints/Common/Constants.cs b/src/GeneratedEndpoints/Common/Constants.cs index 414a6a8..05235a2 100644 --- a/src/GeneratedEndpoints/Common/Constants.cs +++ b/src/GeneratedEndpoints/Common/Constants.cs @@ -102,8 +102,9 @@ internal static partial class Constants internal const string ConfigureMethodName = "Configure"; internal const string AsyncSuffix = "Async"; - internal const string GlobalPrefix = "global::"; internal const string ApplicationJsonContentType = "application/json"; + internal const string GlobalPrefix = "global::"; + internal const string Dot = "."; internal static readonly string[] AttributesNamespaceParts = AttributesNamespace.Split('.'); internal static readonly string[] AspNetCoreHttpNamespaceParts = ["Microsoft", "AspNetCore", "Http"]; diff --git a/src/GeneratedEndpoints/Common/RequestHandler.cs b/src/GeneratedEndpoints/Common/RequestHandler.cs index ca9954d..df42681 100644 --- a/src/GeneratedEndpoints/Common/RequestHandler.cs +++ b/src/GeneratedEndpoints/Common/RequestHandler.cs @@ -1,9 +1,34 @@ +using static GeneratedEndpoints.Common.Constants; + namespace GeneratedEndpoints.Common; -internal readonly record struct RequestHandler( - RequestHandlerClass Class, - RequestHandlerMethod Method, - string HttpMethod, - string Pattern, - string? Name -); +internal readonly record struct RequestHandler +{ + public required RequestHandlerClass Class { get; init; } + public required RequestHandlerMethod Method { get; init; } + public required string HttpMethod { get; init; } + public required string Pattern { get; init; } + public required string? Name { get; init; } + + public string GetFullyQualifiedMethodDisplayName() + { + ReadOnlySpan className = Class.Name; + ReadOnlySpan methodName = Method.Name; + + if (className.StartsWith(GlobalPrefix)) + className = className[GlobalPrefix.Length..]; + + var classLen = className.Length; + var methodLen = methodName.Length; + var total = classLen + 1 + methodLen; + + Span buffer = stackalloc char[total]; + className.CopyTo(buffer); + + buffer[classLen] = '.'; + + methodName.CopyTo(buffer[(classLen + 1)..]); + + return buffer.ToString(); + } +} diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 0a150e9..8a2e1e4 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -1,4 +1,3 @@ -using System.Buffers; using System.Collections.Immutable; using System.Runtime.CompilerServices; using GeneratedEndpoints.Common; @@ -104,7 +103,14 @@ private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToke var (httpMethod, pattern, name) = GetRequestHandlerAttribute(methodSymbol, attribute, cancellationToken); - var requestHandler = new RequestHandler(requestHandlerClass.Value, requestHandlerMethod, httpMethod, pattern, name); + var requestHandler = new RequestHandler + { + Class = requestHandlerClass.Value, + Method = requestHandlerMethod, + HttpMethod = httpMethod, + Pattern = pattern, + Name = name, + }; return requestHandler; } @@ -206,59 +212,26 @@ private static EquatableImmutableArray EnsureUniqueEndpointNames continue; } - builder ??= requestHandlers.AsImmutableArray().ToBuilder(); + builder ??= requestHandlers.AsImmutableArray() + .ToBuilder(); handler = builder[index]; if (!collision.FirstHandlerRenamed) { var firstHandler = builder[collision.FirstIndex]; builder[collision.FirstIndex] = firstHandler with { - Name = GetFullyQualifiedMethodDisplayName(firstHandler), + Name = firstHandler.GetFullyQualifiedMethodDisplayName(), }; collision = collision.WithFirstHandlerRenamed(); } builder[index] = handler with { - Name = GetFullyQualifiedMethodDisplayName(handler), + Name = handler.GetFullyQualifiedMethodDisplayName(), }; nameToCollision[key] = collision; } return builder?.ToEquatableImmutable() ?? requestHandlers; } - - private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestHandler) - { - var className = requestHandler.Class.Name; - var methodName = requestHandler.Method.Name; - - var startIndex = className.StartsWith(GlobalPrefix, StringComparison.Ordinal) ? GlobalPrefix.Length : 0; - var length = className.Length - startIndex; - var containsNestedTypeSeparator = className.IndexOf('+', startIndex, length) >= 0; - - if (length == 0) - return string.Concat(".", methodName); - - var totalLength = length + 1 + methodName.Length; - var buffer = ArrayPool.Shared.Rent(totalLength); - try - { - var destinationIndex = 0; - for (var i = 0; i < length; i++) - { - var character = className[startIndex + i]; - buffer[destinationIndex++] = containsNestedTypeSeparator && character == '+' ? '.' : character; - } - - buffer[destinationIndex++] = '.'; - methodName.CopyTo(0, buffer, destinationIndex, methodName.Length); - - return new string(buffer, 0, totalLength); - } - finally - { - ArrayPool.Shared.Return(buffer); - } - } } From 9541c72c1a43247391093a4e94fab488ab18ccf8 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 01:56:22 -0500 Subject: [PATCH 28/32] Streamline endpoint name collision resolution (#69) --- .../Common/CollisionInfo.cs | 13 ---- .../Common/HandlerNameKey.cs | 36 ---------- .../Common/RequestHandler.cs | 4 +- src/GeneratedEndpoints/MinimalApiGenerator.cs | 65 ++++++++++--------- 4 files changed, 35 insertions(+), 83 deletions(-) delete mode 100644 src/GeneratedEndpoints/Common/CollisionInfo.cs delete mode 100644 src/GeneratedEndpoints/Common/HandlerNameKey.cs diff --git a/src/GeneratedEndpoints/Common/CollisionInfo.cs b/src/GeneratedEndpoints/Common/CollisionInfo.cs deleted file mode 100644 index 192a0d5..0000000 --- a/src/GeneratedEndpoints/Common/CollisionInfo.cs +++ /dev/null @@ -1,13 +0,0 @@ -namespace GeneratedEndpoints.Common; - -internal readonly struct CollisionInfo(int firstIndex, bool firstHandlerRenamed) -{ - public int FirstIndex { get; } = firstIndex; - - public bool FirstHandlerRenamed { get; } = firstHandlerRenamed; - - public CollisionInfo WithFirstHandlerRenamed() - { - return new CollisionInfo(FirstIndex, true); - } -} diff --git a/src/GeneratedEndpoints/Common/HandlerNameKey.cs b/src/GeneratedEndpoints/Common/HandlerNameKey.cs deleted file mode 100644 index 7ce2e11..0000000 --- a/src/GeneratedEndpoints/Common/HandlerNameKey.cs +++ /dev/null @@ -1,36 +0,0 @@ -using System.Runtime.CompilerServices; - -namespace GeneratedEndpoints.Common; - -internal readonly struct HandlerNameKey(string name, string method) : IEquatable -{ - private readonly string _name = name; - private readonly string _method = method; - private readonly int _hashCode = CombineHashCodes(StringComparer.Ordinal.GetHashCode(name), StringComparer.Ordinal.GetHashCode(method)); - - public bool Equals(HandlerNameKey other) - { - return ReferenceEquals(_name, other._name) && ReferenceEquals(_method, other._method) - || (string.Equals(_name, other._name, StringComparison.Ordinal) - && string.Equals(_method, other._method, StringComparison.Ordinal)); - } - - public override bool Equals(object? obj) - { - return obj is HandlerNameKey other && Equals(other); - } - - public override int GetHashCode() - { - return _hashCode; - } - - [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static int CombineHashCodes(int left, int right) - { - unchecked - { - return (left * 397) ^ right; - } - } -} diff --git a/src/GeneratedEndpoints/Common/RequestHandler.cs b/src/GeneratedEndpoints/Common/RequestHandler.cs index df42681..674eb77 100644 --- a/src/GeneratedEndpoints/Common/RequestHandler.cs +++ b/src/GeneratedEndpoints/Common/RequestHandler.cs @@ -2,13 +2,13 @@ namespace GeneratedEndpoints.Common; -internal readonly record struct RequestHandler +internal record struct RequestHandler { public required RequestHandlerClass Class { get; init; } public required RequestHandlerMethod Method { get; init; } public required string HttpMethod { get; init; } public required string Pattern { get; init; } - public required string? Name { get; init; } + public required string? Name { get; set; } public string GetFullyQualifiedMethodDisplayName() { diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 8a2e1e4..6ff5c0e 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -1,5 +1,6 @@ using System.Collections.Immutable; using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; using GeneratedEndpoints.Common; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -189,49 +190,49 @@ private static EquatableImmutableArray NormalizeRequestHandlers( } private static EquatableImmutableArray EnsureUniqueEndpointNames(EquatableImmutableArray requestHandlers) + { + ResolveEndpointNameCollisions(requestHandlers); + return requestHandlers; + } + + private static void ResolveEndpointNameCollisions(EquatableImmutableArray requestHandlers) { if (requestHandlers.Count == 0) - return requestHandlers; + return; - var handlerCount = requestHandlers.Count; - Dictionary? nameToCollision = null; - ImmutableArray.Builder? builder = null; + var handlers = requestHandlers.AsImmutableArray(); + var raw = ImmutableCollectionsMarshal.AsArray(handlers); + if (raw is null) + return; - for (var index = 0; index < handlerCount; index++) + var span = raw.AsSpan(); + + for (var outerIndex = 0; outerIndex < span.Length - 1; outerIndex++) { - var handler = builder is null ? requestHandlers[index] : builder[index]; - var name = handler.Name; - if (string.IsNullOrEmpty(name)) + ref var outer = ref span[outerIndex]; + var outerName = outer.Name; + if (string.IsNullOrEmpty(outerName)) continue; - nameToCollision ??= new Dictionary(handlerCount); - var key = new HandlerNameKey(name!, handler.Method.Name); - if (!nameToCollision.TryGetValue(key, out var collision)) + var collisionHandled = false; + for (var innerIndex = outerIndex + 1; innerIndex < span.Length; innerIndex++) { - nameToCollision.Add(key, new CollisionInfo(index, false)); - continue; - } + ref var inner = ref span[innerIndex]; + var innerName = inner.Name; + if (innerName is null) + continue; - builder ??= requestHandlers.AsImmutableArray() - .ToBuilder(); - handler = builder[index]; - if (!collision.FirstHandlerRenamed) - { - var firstHandler = builder[collision.FirstIndex]; - builder[collision.FirstIndex] = firstHandler with + if (!string.Equals(outerName, innerName, StringComparison.Ordinal)) + continue; + + if (!collisionHandled) { - Name = firstHandler.GetFullyQualifiedMethodDisplayName(), - }; - collision = collision.WithFirstHandlerRenamed(); - } + outer.Name = outer.GetFullyQualifiedMethodDisplayName(); + collisionHandled = true; + } - builder[index] = handler with - { - Name = handler.GetFullyQualifiedMethodDisplayName(), - }; - nameToCollision[key] = collision; + inner.Name = inner.GetFullyQualifiedMethodDisplayName(); + } } - - return builder?.ToEquatableImmutable() ?? requestHandlers; } } From b623d5de53043f4625a6b2a3bbb86b63865673bf Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 02:11:17 -0500 Subject: [PATCH 29/32] Refactor. --- .../Common/EquatableImmutableArray`1.cs | 25 ++++++++++--------- .../Common/RequestHandler.cs | 12 ++++++--- src/GeneratedEndpoints/MinimalApiGenerator.cs | 21 ++++------------ 3 files changed, 27 insertions(+), 31 deletions(-) diff --git a/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs b/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs index 04407e9..9270424 100644 --- a/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs +++ b/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs @@ -22,37 +22,38 @@ namespace GeneratedEndpoints.Common; private ImmutableArray Array => _array ?? ImmutableArray.Empty; private readonly ImmutableArray? _array; - internal ImmutableArray AsImmutableArray() - { - return Array; - } - internal EquatableImmutableArray(ImmutableArray? array) { _array = array; } /// - /// Sorts the underlying array in place using the specified comparer. - /// WARNING: This mutates the underlying storage of the ImmutableArray and + /// Gets the underlying array, or if the array is empty. WARNING: This returns the underlying storage of the ImmutableArray and /// must only be used when you can guarantee no other code observes it as immutable. /// - internal EquatableImmutableArray SortInPlace(IComparer? comparer = null) + internal T[]? AsArray() + { + return ImmutableCollectionsMarshal.AsArray(Array); + } + + /// + /// Sorts the underlying array in place using the specified comparer. WARNING: This mutates the underlying storage of the ImmutableArray and must only be + /// used when you can guarantee no other code observes it as immutable. + /// + internal void SortInPlace(IComparer? comparer = null) { if (_array is null) - return this; + return; var array = _array.Value; if (array.Length <= 1) - return this; + return; comparer ??= Comparer.Default; var raw = ImmutableCollectionsMarshal.AsArray(array); if (raw is not null) System.Array.Sort(raw, comparer); - - return this; } /// diff --git a/src/GeneratedEndpoints/Common/RequestHandler.cs b/src/GeneratedEndpoints/Common/RequestHandler.cs index 674eb77..6017b77 100644 --- a/src/GeneratedEndpoints/Common/RequestHandler.cs +++ b/src/GeneratedEndpoints/Common/RequestHandler.cs @@ -4,13 +4,19 @@ namespace GeneratedEndpoints.Common; internal record struct RequestHandler { + private string? _name; public required RequestHandlerClass Class { get; init; } public required RequestHandlerMethod Method { get; init; } public required string HttpMethod { get; init; } public required string Pattern { get; init; } - public required string? Name { get; set; } - public string GetFullyQualifiedMethodDisplayName() + public required string? Name + { + readonly get => _name; + init => _name = value; + } + + public void SetFullyQualifiedName() { ReadOnlySpan className = Class.Name; ReadOnlySpan methodName = Method.Name; @@ -29,6 +35,6 @@ public string GetFullyQualifiedMethodDisplayName() methodName.CopyTo(buffer[(classLen + 1)..]); - return buffer.ToString(); + _name = buffer.ToString(); } } diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 6ff5c0e..432eb02 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -1,6 +1,5 @@ using System.Collections.Immutable; using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; using GeneratedEndpoints.Common; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -183,25 +182,15 @@ private static EquatableImmutableArray NormalizeRequestHandlers( if (requestHandlers.Count <= 1) return requestHandlers; - var sorted = requestHandlers.SortInPlace(RequestHandlerComparer.Instance); - var unique = EnsureUniqueEndpointNames(sorted); - - return unique; - } - - private static EquatableImmutableArray EnsureUniqueEndpointNames(EquatableImmutableArray requestHandlers) - { + requestHandlers.SortInPlace(RequestHandlerComparer.Instance); ResolveEndpointNameCollisions(requestHandlers); + return requestHandlers; } private static void ResolveEndpointNameCollisions(EquatableImmutableArray requestHandlers) { - if (requestHandlers.Count == 0) - return; - - var handlers = requestHandlers.AsImmutableArray(); - var raw = ImmutableCollectionsMarshal.AsArray(handlers); + var raw = requestHandlers.AsArray(); if (raw is null) return; @@ -227,11 +216,11 @@ private static void ResolveEndpointNameCollisions(EquatableImmutableArray Date: Mon, 17 Nov 2025 02:28:28 -0500 Subject: [PATCH 30/32] Optimize endpoint name collision handling (#70) --- src/GeneratedEndpoints/MinimalApiGenerator.cs | 39 +++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 432eb02..3cd6dd3 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -1,3 +1,4 @@ +using System.Collections.Generic; using System.Collections.Immutable; using System.Runtime.CompilerServices; using GeneratedEndpoints.Common; @@ -195,33 +196,31 @@ private static void ResolveEndpointNameCollisions(EquatableImmutableArray(span.Length, StringComparer.Ordinal); - for (var outerIndex = 0; outerIndex < span.Length - 1; outerIndex++) + for (var index = 0; index < span.Length; index++) { - ref var outer = ref span[outerIndex]; - var outerName = outer.Name; - if (string.IsNullOrEmpty(outerName)) + ref var handler = ref span[index]; + var name = handler.Name; + if (string.IsNullOrEmpty(name)) continue; + var nonEmptyName = name!; - var collisionHandled = false; - for (var innerIndex = outerIndex + 1; innerIndex < span.Length; innerIndex++) + if (!seen.TryGetValue(nonEmptyName, out var state)) { - ref var inner = ref span[innerIndex]; - var innerName = inner.Name; - if (innerName is null) - continue; - - if (!string.Equals(outerName, innerName, StringComparison.Ordinal)) - continue; - - if (!collisionHandled) - { - outer.SetFullyQualifiedName(); - collisionHandled = true; - } + seen.Add(nonEmptyName, index); + continue; + } - inner.SetFullyQualifiedName(); + var firstIndex = state >= 0 ? state : ~state; + if (state >= 0) + { + ref var firstHandler = ref span[firstIndex]; + firstHandler.SetFullyQualifiedName(); + seen[nonEmptyName] = ~firstIndex; } + + handler.SetFullyQualifiedName(); } } } From 5ebcee9d4cee6e0abc3242cc6255400930f11b04 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 02:46:02 -0500 Subject: [PATCH 31/32] Cleanup. --- .../UseEndpointHandlersGenerator.cs | 52 +++++++++++-------- ...08C7DE832_MapEndpointHandlers.verified.txt | 4 +- ...A05F2C177_MapEndpointHandlers.verified.txt | 4 +- 3 files changed, 34 insertions(+), 26 deletions(-) diff --git a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs index 9eeee1b..93ba94e 100644 --- a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs +++ b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs @@ -92,7 +92,7 @@ private static bool HasRateLimitedHandlers(EquatableImmutableArray() - .AddEndpointFilter(); + .AddEndpointFilter() + .AddEndpointFilter(); return builder; } diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ConfigureAndFiltersMatrix_401A05F2C177_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ConfigureAndFiltersMatrix_401A05F2C177_MapEndpointHandlers.verified.txt index d1abab8..d95bea0 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ConfigureAndFiltersMatrix_401A05F2C177_MapEndpointHandlers.verified.txt +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ConfigureAndFiltersMatrix_401A05F2C177_MapEndpointHandlers.verified.txt @@ -24,8 +24,8 @@ internal static class EndpointRouteBuilderExtensions { builder.MapGet("/configure-filters", global::GeneratedEndpointsTests.ConfigureFilterEndpoints.Handle) .WithName("Handle") - .AddEndpointFilter() - .AddEndpointFilter(); + .AddEndpointFilter() + .AddEndpointFilter(); return builder; } From c1c89c1e52e0a03f14768c67bd7d30c3b8072e03 Mon Sep 17 00:00:00 2001 From: Jean-Sebastien Carle <29762210+jscarle@users.noreply.github.com> Date: Mon, 17 Nov 2025 02:46:42 -0500 Subject: [PATCH 32/32] Reformat. --- README.md | 76 +- .../Common/AttributeDataExtensions.cs | 2 + .../Common/ConfigureMethodDetails.cs | 5 +- .../Common/Constants.GeneratedSources.cs | 966 +++++++++--------- src/GeneratedEndpoints/Common/Constants.cs | 43 +- .../Common/EndpointConfigurationFactory.cs | 13 +- .../Common/HttpAttributeDefinition.cs | 8 +- .../IncrementalValueProviderExtensions.cs | 102 +- .../Common/MethodSymbolExtensions.cs | 3 +- .../Common/ProducesProblemMetadata.cs | 6 +- .../ProducesValidationProblemMetadata.cs | 6 +- .../Common/RequestHandler.cs | 3 +- .../Common/RequestHandlerAttributeKind.cs | 2 +- .../Common/RequestHandlerClassCacheEntry.cs | 18 +- .../GeneratedEndpoints.csproj | 10 +- src/GeneratedEndpoints/MinimalApiGenerator.cs | 1 - .../GeneratedEndpoints.Tests.Lab.csproj | 8 +- .../AttributeGenerationTests.cs | 120 ++- .../Common/ScenarioNamer.cs | 4 +- .../Common/SourceFactory.cs | 181 +--- .../Common/TestHelpers.cs | 3 +- .../GeneratedEndpoints.Tests.csproj | 30 +- .../GeneratedSourceTests.cs | 209 ++-- .../IndividualTests.cs | 205 ++-- 24 files changed, 938 insertions(+), 1086 deletions(-) diff --git a/README.md b/README.md index 1f6aa08..31fc6f4 100644 --- a/README.md +++ b/README.md @@ -140,41 +140,41 @@ applied to all endpoints within the class. ## Attribute Reference -| Definition | Usage | Description | -| --- | --- | --- | -| `[Accepts(string contentType = "application/json", params string[] additionalContentTypes, RequestType = null, IsOptional = false)]` | Method | Declares the accepted request body CLR type, optional status, and list of content types for the handler. | -| `[Accepts(string contentType = "application/json", params string[] additionalContentTypes, IsOptional = false)]` | Method | Generic shortcut for specifying the request type and accepted content types for the handler. | -| `[AllowAnonymous]` | Class or Method | Allows the annotated endpoint or class to bypass authorization requirements. | -| `[Description(string description)]` | Method | Sets the OpenAPI description metadata for the generated endpoint. | -| `[DisableAntiforgery]` | Class or Method | Disables antiforgery protection for the annotated endpoint(s). | -| `[DisableRequestTimeout]` | Class or Method | Disables request timeout enforcement for the annotated endpoint(s). | -| `[DisableValidation]` | Class or Method | Disables automatic request validation (when supported) for the annotated endpoint(s). | -| `[DisplayName(string displayName)]` | Method | Overrides the endpoint display name used in diagnostics and metadata. | -| `[EndpointFilter(Type filterType)]` | Class or Method | Adds the specified endpoint filter type to the handler pipeline. | -| `[EndpointFilter]` | Class or Method | Generic form for registering an endpoint filter type on the handler pipeline. | -| `[ExcludeFromDescription]` | Class or Method | Hides the endpoint or class from generated API descriptions (e.g., OpenAPI). | -| `[MapConnect(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP CONNECT endpoint using the supplied route pattern and optional name. | -| `[MapDelete(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP DELETE endpoint using the supplied route pattern and optional name. | -| `[MapFallback(string pattern = "", Name = null)]` | Method | Maps the method as the fallback endpoint invoked when no other route matches. | -| `[MapGroup(string pattern, Name = null)]` | Class | Assigns a route group pattern and optional endpoint group name to every handler in the class. | -| `[MapGet(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP GET endpoint using the supplied route pattern and optional name. | -| `[MapHead(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP HEAD endpoint using the supplied route pattern and optional name. | -| `[MapOptions(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP OPTIONS endpoint using the supplied route pattern and optional name. | -| `[MapPatch(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP PATCH endpoint using the supplied route pattern and optional name. | -| `[MapPost(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP POST endpoint using the supplied route pattern and optional name. | -| `[MapPut(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP PUT endpoint using the supplied route pattern and optional name. | -| `[MapQuery(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP QUERY endpoint using the supplied route pattern and optional name. | -| `[MapTrace(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP TRACE endpoint using the supplied route pattern and optional name. | -| `[Order(int order)]` | Method | Controls the order in which endpoint conventions are applied to the handler. | -| `[ProducesProblem(int statusCode = StatusCodes.Status500InternalServerError, string? contentType = null, params string[] additionalContentTypes)]` | Method | Declares that the endpoint emits a problem details payload for the given status code and content types. | -| `[ProducesResponse(int statusCode = StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes, ResponseType = null)]` | Method | Declares response metadata for the handler, including status code, optional CLR type, and content types. | -| `[ProducesResponse(int statusCode = StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes)]` | Method | Generic shorthand for declaring the CLR response type along with status code and content types. | -| `[ProducesValidationProblem(int statusCode = StatusCodes.Status400BadRequest, string? contentType = null, params string[] additionalContentTypes)]` | Method | Declares that the endpoint returns validation problem details for the specified status code and content types. | -| `[RequestTimeout(string? policyName = null)]` | Class or Method | Applies the default or a named request-timeout policy to the handler(s). | -| `[RequireAuthorization(params string[] policies)]` | Class or Method | Enforces authorization on the handler(s), optionally scoping access to specific policies. | -| `[RequireCors(string? policyName = null)]` | Class or Method | Requires the default or a named CORS policy for the annotated handler(s). | -| `[RequireHost(params string[] hosts)]` | Class or Method | Restricts the handler(s) to the specified allowed hostnames. | -| `[RequireRateLimiting(string policyName)]` | Class or Method | Enforces the named rate-limiting policy on the annotated handler(s). | -| `[ShortCircuit]` | Class or Method | Marks the handler(s) to short-circuit the request pipeline when invoked. | -| `[Summary(string summary)]` | Class or Method | Sets the summary metadata applied to the generated endpoint(s). | -| `[Tags(params string[] tags)]` | Class or Method | Assigns OpenAPI tags to the annotated handler(s) for grouping in API docs. | +| Definition | Usage | Description | +|---------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------|----------------------------------------------------------------------------------------------------------------| +| `[Accepts(string contentType = "application/json", params string[] additionalContentTypes, RequestType = null, IsOptional = false)]` | Method | Declares the accepted request body CLR type, optional status, and list of content types for the handler. | +| `[Accepts(string contentType = "application/json", params string[] additionalContentTypes, IsOptional = false)]` | Method | Generic shortcut for specifying the request type and accepted content types for the handler. | +| `[AllowAnonymous]` | Class or Method | Allows the annotated endpoint or class to bypass authorization requirements. | +| `[Description(string description)]` | Method | Sets the OpenAPI description metadata for the generated endpoint. | +| `[DisableAntiforgery]` | Class or Method | Disables antiforgery protection for the annotated endpoint(s). | +| `[DisableRequestTimeout]` | Class or Method | Disables request timeout enforcement for the annotated endpoint(s). | +| `[DisableValidation]` | Class or Method | Disables automatic request validation (when supported) for the annotated endpoint(s). | +| `[DisplayName(string displayName)]` | Method | Overrides the endpoint display name used in diagnostics and metadata. | +| `[EndpointFilter(Type filterType)]` | Class or Method | Adds the specified endpoint filter type to the handler pipeline. | +| `[EndpointFilter]` | Class or Method | Generic form for registering an endpoint filter type on the handler pipeline. | +| `[ExcludeFromDescription]` | Class or Method | Hides the endpoint or class from generated API descriptions (e.g., OpenAPI). | +| `[MapConnect(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP CONNECT endpoint using the supplied route pattern and optional name. | +| `[MapDelete(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP DELETE endpoint using the supplied route pattern and optional name. | +| `[MapFallback(string pattern = "", Name = null)]` | Method | Maps the method as the fallback endpoint invoked when no other route matches. | +| `[MapGroup(string pattern, Name = null)]` | Class | Assigns a route group pattern and optional endpoint group name to every handler in the class. | +| `[MapGet(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP GET endpoint using the supplied route pattern and optional name. | +| `[MapHead(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP HEAD endpoint using the supplied route pattern and optional name. | +| `[MapOptions(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP OPTIONS endpoint using the supplied route pattern and optional name. | +| `[MapPatch(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP PATCH endpoint using the supplied route pattern and optional name. | +| `[MapPost(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP POST endpoint using the supplied route pattern and optional name. | +| `[MapPut(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP PUT endpoint using the supplied route pattern and optional name. | +| `[MapQuery(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP QUERY endpoint using the supplied route pattern and optional name. | +| `[MapTrace(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP TRACE endpoint using the supplied route pattern and optional name. | +| `[Order(int order)]` | Method | Controls the order in which endpoint conventions are applied to the handler. | +| `[ProducesProblem(int statusCode = StatusCodes.Status500InternalServerError, string? contentType = null, params string[] additionalContentTypes)]` | Method | Declares that the endpoint emits a problem details payload for the given status code and content types. | +| `[ProducesResponse(int statusCode = StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes, ResponseType = null)]` | Method | Declares response metadata for the handler, including status code, optional CLR type, and content types. | +| `[ProducesResponse(int statusCode = StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes)]` | Method | Generic shorthand for declaring the CLR response type along with status code and content types. | +| `[ProducesValidationProblem(int statusCode = StatusCodes.Status400BadRequest, string? contentType = null, params string[] additionalContentTypes)]` | Method | Declares that the endpoint returns validation problem details for the specified status code and content types. | +| `[RequestTimeout(string? policyName = null)]` | Class or Method | Applies the default or a named request-timeout policy to the handler(s). | +| `[RequireAuthorization(params string[] policies)]` | Class or Method | Enforces authorization on the handler(s), optionally scoping access to specific policies. | +| `[RequireCors(string? policyName = null)]` | Class or Method | Requires the default or a named CORS policy for the annotated handler(s). | +| `[RequireHost(params string[] hosts)]` | Class or Method | Restricts the handler(s) to the specified allowed hostnames. | +| `[RequireRateLimiting(string policyName)]` | Class or Method | Enforces the named rate-limiting policy on the annotated handler(s). | +| `[ShortCircuit]` | Class or Method | Marks the handler(s) to short-circuit the request pipeline when invoked. | +| `[Summary(string summary)]` | Class or Method | Sets the summary metadata applied to the generated endpoint(s). | +| `[Tags(params string[] tags)]` | Class or Method | Assigns OpenAPI tags to the annotated handler(s) for grouping in API docs. | diff --git a/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs index a14d6a1..1a10e01 100644 --- a/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs +++ b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs @@ -41,7 +41,9 @@ internal static class AttributeDataExtensions return normalized.ToEquatableImmutableArray(); } else if (arg.Value is string singleHost && !string.IsNullOrWhiteSpace(singleHost)) + { return new[] { singleHost.Trim() }.ToEquatableImmutableArray(); + } return null; } diff --git a/src/GeneratedEndpoints/Common/ConfigureMethodDetails.cs b/src/GeneratedEndpoints/Common/ConfigureMethodDetails.cs index 0e3b859..0a9d741 100644 --- a/src/GeneratedEndpoints/Common/ConfigureMethodDetails.cs +++ b/src/GeneratedEndpoints/Common/ConfigureMethodDetails.cs @@ -1,6 +1,3 @@ namespace GeneratedEndpoints.Common; -internal readonly record struct ConfigureMethodDetails( - bool HasConfigureMethod, - bool ConfigureMethodAcceptsServiceProvider -); +internal readonly record struct ConfigureMethodDetails(bool HasConfigureMethod, bool ConfigureMethodAcceptsServiceProvider); diff --git a/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs b/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs index 534d742..6b6d8e5 100644 --- a/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs +++ b/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs @@ -7,18 +7,18 @@ namespace GeneratedEndpoints.Common; internal static partial class Constants { internal static readonly string FileHeader = $""" - //----------------------------------------------------------------------------- - // - // This code was generated by {nameof(MinimalApiGenerator)} which can be found - // in the {typeof(MinimalApiGenerator).Namespace} namespace. - // - // Changes to this file may cause incorrect behavior - // and will be lost if the code is regenerated. - // - //----------------------------------------------------------------------------- - - #nullable enable - """; + //----------------------------------------------------------------------------- + // + // This code was generated by {nameof(MinimalApiGenerator)} which can be found + // in the {typeof(MinimalApiGenerator).Namespace} namespace. + // + // Changes to this file may cause incorrect behavior + // and will be lost if the code is regenerated. + // + //----------------------------------------------------------------------------- + + #nullable enable + """; internal static readonly ImmutableArray HttpAttributeDefinitions = [ @@ -39,509 +39,552 @@ internal static partial class Constants HttpAttributeDefinitions.ToImmutableDictionary(static definition => definition.Name); internal static readonly SourceText RequireAuthorizationAttributeSourceText = SourceText.From($$""" + {{FileHeader}} + + namespace {{AttributesNamespace}}; + + /// + /// Specifies that authorization is required for the annotated endpoint or class. + /// Optionally restricts access to the specified authorization policies. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{RequireAuthorizationAttributeName}} : global::System.Attribute + { + /// + /// Gets the policy names that the endpoint or class requires. + /// + public string[] PolicyNames { get; } + + /// + /// Marks the endpoint or class as requiring authorization. + /// + public {{RequireAuthorizationAttributeName}}() + { + PolicyNames = []; + } + + /// + /// Marks the endpoint or class as requiring authorization with one or more policies. + /// + public {{RequireAuthorizationAttributeName}}(params string[] policyNames) + { + PolicyNames = policyNames ?? []; + } + } + """, Encoding.UTF8 + ); + + internal static readonly SourceText RequireCorsAttributeSourceText = SourceText.From($$""" + {{FileHeader}} + + namespace {{AttributesNamespace}}; + + /// + /// Specifies that the annotated endpoint requires a configured CORS policy. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{RequireCorsAttributeName}} : global::System.Attribute + { + /// + /// Gets the optional CORS policy name. + /// + public string? PolicyName { get; } + + /// + /// Marks the endpoint or class as requiring the default CORS policy. + /// + public {{RequireCorsAttributeName}}() + { + } + + /// + /// Marks the endpoint or class as requiring the specified named CORS policy. + /// + public {{RequireCorsAttributeName}}(string policyName) + { + PolicyName = policyName; + } + } + """, Encoding.UTF8 + ); + + internal static readonly SourceText RequireRateLimitingAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; /// - /// Specifies that authorization is required for the annotated endpoint or class. - /// Optionally restricts access to the specified authorization policies. + /// Specifies that the annotated endpoint requires the provided rate limiting policy. /// [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{RequireAuthorizationAttributeName}} : global::System.Attribute + internal sealed class {{RequireRateLimitingAttributeName}} : global::System.Attribute { /// - /// Gets the policy names that the endpoint or class requires. - /// - public string[] PolicyNames { get; } - - /// - /// Marks the endpoint or class as requiring authorization. + /// Initializes a new instance of the class. /// - public {{RequireAuthorizationAttributeName}}() + /// The rate limiting policy to apply. + public {{RequireRateLimitingAttributeName}}(string policyName) { - PolicyNames = []; + PolicyName = policyName; } /// - /// Marks the endpoint or class as requiring authorization with one or more policies. + /// Gets the rate limiting policy name. /// - public {{RequireAuthorizationAttributeName}}(params string[] policyNames) - { - PolicyNames = policyNames ?? []; - } + public string PolicyName { get; } } """, Encoding.UTF8 ); - internal static readonly SourceText RequireCorsAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; + internal static readonly SourceText RequireHostAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - /// - /// Specifies that the annotated endpoint requires a configured CORS policy. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{RequireCorsAttributeName}} : global::System.Attribute - { - /// - /// Gets the optional CORS policy name. - /// - public string? PolicyName { get; } + namespace {{AttributesNamespace}}; - /// - /// Marks the endpoint or class as requiring the default CORS policy. - /// - public {{RequireCorsAttributeName}}() - { - } + /// + /// Specifies the allowed hosts for the annotated endpoint or class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{RequireHostAttributeName}} : global::System.Attribute + { + /// + /// Initializes a new instance of the class. + /// + /// The hosts that are allowed to access the endpoint. + public {{RequireHostAttributeName}}(params string[] hosts) + { + Hosts = hosts ?? []; + } - /// - /// Marks the endpoint or class as requiring the specified named CORS policy. - /// - public {{RequireCorsAttributeName}}(string policyName) - { - PolicyName = policyName; - } - } - """, Encoding.UTF8 + /// + /// Gets the allowed hosts. + /// + public string[] Hosts { get; } + } + """, Encoding.UTF8 ); - internal static readonly SourceText RequireRateLimitingAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText DisableAntiforgeryAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; /// - /// Specifies that the annotated endpoint requires the provided rate limiting policy. + /// Disables antiforgery protection for the annotated endpoint or class. /// [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{RequireRateLimitingAttributeName}} : global::System.Attribute + internal sealed class {{DisableAntiforgeryAttributeName}} : global::System.Attribute { - /// - /// Initializes a new instance of the class. - /// - /// The rate limiting policy to apply. - public {{RequireRateLimitingAttributeName}}(string policyName) - { - PolicyName = policyName; - } - - /// - /// Gets the rate limiting policy name. - /// - public string PolicyName { get; } } + """, Encoding.UTF8 ); - internal static readonly SourceText RequireHostAttributeSourceText = SourceText.From($$""" - {{FileHeader}} + internal static readonly SourceText ShortCircuitAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - namespace {{AttributesNamespace}}; + namespace {{AttributesNamespace}}; - /// - /// Specifies the allowed hosts for the annotated endpoint or class. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{RequireHostAttributeName}} : global::System.Attribute - { - /// - /// Initializes a new instance of the class. - /// - /// The hosts that are allowed to access the endpoint. - public {{RequireHostAttributeName}}(params string[] hosts) - { - Hosts = hosts ?? []; - } + /// + /// Marks the annotated endpoint or class to short-circuit the request pipeline. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{ShortCircuitAttributeName}} : global::System.Attribute + { + } - /// - /// Gets the allowed hosts. - /// - public string[] Hosts { get; } - } - """, Encoding.UTF8 + """, Encoding.UTF8 ); - internal static readonly SourceText DisableAntiforgeryAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText DisableRequestTimeoutAttributeSourceText = SourceText.From($$""" + {{FileHeader}} + + namespace {{AttributesNamespace}}; + + /// + /// Disables the request timeout for the annotated endpoint or class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{DisableRequestTimeoutAttributeName}} : global::System.Attribute + { + } + + """, Encoding.UTF8 + ); + + internal static readonly SourceText DisableValidationAttributeSourceText = SourceText.From($$""" + #if NET10_0_OR_GREATER {{FileHeader}} namespace {{AttributesNamespace}}; /// - /// Disables antiforgery protection for the annotated endpoint or class. + /// Disables request validation for the annotated endpoint or class when targeting .NET 10 or later. /// [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{DisableAntiforgeryAttributeName}} : global::System.Attribute + internal sealed class {{DisableValidationAttributeName}} : global::System.Attribute { } + #endif """, Encoding.UTF8 ); - internal static readonly SourceText ShortCircuitAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; + internal static readonly SourceText RequestTimeoutAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - /// - /// Marks the annotated endpoint or class to short-circuit the request pipeline. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{ShortCircuitAttributeName}} : global::System.Attribute - { - } + namespace {{AttributesNamespace}}; - """, Encoding.UTF8 - ); + /// + /// Applies the request timeout metadata to the annotated endpoint or class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{RequestTimeoutAttributeName}} : global::System.Attribute + { + /// + /// Gets the optional request timeout policy name. + /// + public string? PolicyName { get; init; } - internal static readonly SourceText DisableRequestTimeoutAttributeSourceText = SourceText.From($$""" - {{FileHeader}} + /// + /// Applies the default request timeout behavior. + /// + public {{RequestTimeoutAttributeName}}() + { + } - namespace {{AttributesNamespace}}; + /// + /// Applies the specified request timeout policy. + /// + /// The request timeout policy name. + public {{RequestTimeoutAttributeName}}(string policyName) + { + PolicyName = policyName; + } + } - /// - /// Disables the request timeout for the annotated endpoint or class. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{DisableRequestTimeoutAttributeName}} : global::System.Attribute - { - } + """, Encoding.UTF8 + ); - """, Encoding.UTF8 + internal static readonly SourceText OrderAttributeSourceText = SourceText.From($$""" + {{FileHeader}} + + namespace {{AttributesNamespace}}; + + /// + /// Specifies the order for the annotated endpoint when building conventions. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{OrderAttributeName}} : global::System.Attribute + { + /// + /// Gets the order that will be applied to the endpoint. + /// + public int Order { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The order value to apply to the endpoint. + public {{OrderAttributeName}}(int order) + { + Order = order; + } + } + + """, Encoding.UTF8 ); - internal static readonly SourceText DisableValidationAttributeSourceText = SourceText.From($$""" - #if NET10_0_OR_GREATER - {{FileHeader}} + internal static readonly SourceText MapGroupAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - namespace {{AttributesNamespace}}; + namespace {{AttributesNamespace}}; - /// - /// Disables request validation for the annotated endpoint or class when targeting .NET 10 or later. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{DisableValidationAttributeName}} : global::System.Attribute - { - } - #endif + /// + /// Specifies the route group for the annotated class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class, Inherited = false, AllowMultiple = false)] + internal sealed class {{MapGroupAttributeName}} : global::System.Attribute + { + /// + /// Gets the route group pattern. + /// + public string Pattern { get; } + + /// + /// Gets or sets the endpoint group name. + /// + public string? Name { get; init; } + + /// + /// Initializes a new instance of the class. + /// + /// The route group pattern to apply. + public {{MapGroupAttributeName}}(string pattern) + { + Pattern = pattern; + } + } - """, Encoding.UTF8 + """, Encoding.UTF8 ); - internal static readonly SourceText RequestTimeoutAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; + internal static readonly SourceText SummaryAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - /// - /// Applies the request timeout metadata to the annotated endpoint or class. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{RequestTimeoutAttributeName}} : global::System.Attribute - { - /// - /// Gets the optional request timeout policy name. - /// - public string? PolicyName { get; init; } + namespace {{AttributesNamespace}}; - /// - /// Applies the default request timeout behavior. - /// - public {{RequestTimeoutAttributeName}}() - { - } + /// + /// Specifies the summary metadata for the annotated endpoint. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{SummaryAttributeName}} : global::System.Attribute + { + /// + /// Gets the summary value for the endpoint. + /// + public string Summary { get; } - /// - /// Applies the specified request timeout policy. - /// - /// The request timeout policy name. - public {{RequestTimeoutAttributeName}}(string policyName) - { - PolicyName = policyName; - } - } + /// + /// Initializes a new instance of the class. + /// + /// The summary to apply to the endpoint. + public {{SummaryAttributeName}}(string summary) + { + Summary = summary; + } + } - """, Encoding.UTF8 + """, Encoding.UTF8 ); - internal static readonly SourceText OrderAttributeSourceText = SourceText.From($$""" - {{FileHeader}} + internal static readonly SourceText AcceptsAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - namespace {{AttributesNamespace}}; + namespace {{AttributesNamespace}}; - /// - /// Specifies the order for the annotated endpoint when building conventions. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{OrderAttributeName}} : global::System.Attribute - { - /// - /// Gets the order that will be applied to the endpoint. - /// - public int Order { get; } + /// + /// Specifies the request type and content types accepted by the annotated endpoint or class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{AcceptsAttributeName}} : global::System.Attribute + { + /// + /// Gets the request type accepted by the endpoint. + /// + public global::System.Type? RequestType { get; init; } - /// - /// Initializes a new instance of the class. - /// - /// The order value to apply to the endpoint. - public {{OrderAttributeName}}(int order) - { - Order = order; - } - } + /// + /// Gets a value indicating whether the request body is optional. + /// + public bool IsOptional { get; init; } - """, Encoding.UTF8 - ); + /// + /// Gets the primary content type accepted by the endpoint. + /// + public string ContentType { get; } - internal static readonly SourceText MapGroupAttributeSourceText = SourceText.From($$""" - {{FileHeader}} + /// + /// Gets the additional content types accepted by the endpoint. + /// + public string[] AdditionalContentTypes { get; } - namespace {{AttributesNamespace}}; + /// + /// Initializes a new instance of the class. + /// + /// The primary content type accepted by the endpoint. + /// Additional content types accepted by the endpoint. + public {{AcceptsAttributeName}}(string contentType = "application/json", params string[] additionalContentTypes) + { + ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; + AdditionalContentTypes = additionalContentTypes ?? []; + } + } /// - /// Specifies the route group for the annotated class. + /// Specifies the request type using a generic argument and the content types accepted by the annotated endpoint or class. /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class, Inherited = false, AllowMultiple = false)] - internal sealed class {{MapGroupAttributeName}} : global::System.Attribute + /// The CLR type of the request body. + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{AcceptsAttributeName}} : global::System.Attribute { /// - /// Gets the route group pattern. + /// Gets the request type accepted by the endpoint. + /// + public global::System.Type RequestType => typeof(TRequest); + + /// + /// Gets a value indicating whether the request body is optional. + /// + public bool IsOptional { get; init; } + + /// + /// Gets the primary content type accepted by the endpoint. /// - public string Pattern { get; } + public string ContentType { get; } /// - /// Gets or sets the endpoint group name. + /// Gets the additional content types accepted by the endpoint. /// - public string? Name { get; init; } + public string[] AdditionalContentTypes { get; } /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the generic Accepts attribute class. /// - /// The route group pattern to apply. - public {{MapGroupAttributeName}}(string pattern) + /// The primary content type accepted by the endpoint. + /// Additional content types accepted by the endpoint. + public {{AcceptsAttributeName}}(string contentType = "application/json", params string[] additionalContentTypes) { - Pattern = pattern; + ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; + AdditionalContentTypes = additionalContentTypes ?? []; } } """, Encoding.UTF8 ); - internal static readonly SourceText SummaryAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; - - /// - /// Specifies the summary metadata for the annotated endpoint. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{SummaryAttributeName}} : global::System.Attribute - { - /// - /// Gets the summary value for the endpoint. - /// - public string Summary { get; } - - /// - /// Initializes a new instance of the class. - /// - /// The summary to apply to the endpoint. - public {{SummaryAttributeName}}(string summary) - { - Summary = summary; - } - } - - """, Encoding.UTF8 - ); - - internal static readonly SourceText AcceptsAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; - - /// - /// Specifies the request type and content types accepted by the annotated endpoint or class. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{AcceptsAttributeName}} : global::System.Attribute - { - /// - /// Gets the request type accepted by the endpoint. - /// - public global::System.Type? RequestType { get; init; } - - /// - /// Gets a value indicating whether the request body is optional. - /// - public bool IsOptional { get; init; } - - /// - /// Gets the primary content type accepted by the endpoint. - /// - public string ContentType { get; } - - /// - /// Gets the additional content types accepted by the endpoint. - /// - public string[] AdditionalContentTypes { get; } - - /// - /// Initializes a new instance of the class. - /// - /// The primary content type accepted by the endpoint. - /// Additional content types accepted by the endpoint. - public {{AcceptsAttributeName}}(string contentType = "application/json", params string[] additionalContentTypes) - { - ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; - AdditionalContentTypes = additionalContentTypes ?? []; - } - } - - /// - /// Specifies the request type using a generic argument and the content types accepted by the annotated endpoint or class. - /// - /// The CLR type of the request body. - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{AcceptsAttributeName}} : global::System.Attribute - { - /// - /// Gets the request type accepted by the endpoint. - /// - public global::System.Type RequestType => typeof(TRequest); - - /// - /// Gets a value indicating whether the request body is optional. - /// - public bool IsOptional { get; init; } - - /// - /// Gets the primary content type accepted by the endpoint. - /// - public string ContentType { get; } - - /// - /// Gets the additional content types accepted by the endpoint. - /// - public string[] AdditionalContentTypes { get; } - - /// - /// Initializes a new instance of the generic Accepts attribute class. - /// - /// The primary content type accepted by the endpoint. - /// Additional content types accepted by the endpoint. - public {{AcceptsAttributeName}}(string contentType = "application/json", params string[] additionalContentTypes) - { - ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; - AdditionalContentTypes = additionalContentTypes ?? []; - } - } - - """, Encoding.UTF8 - ); - internal static readonly SourceText EndpointFilterAttributeSourceText = SourceText.From($$""" - {{FileHeader}} + {{FileHeader}} - namespace {{AttributesNamespace}}; + namespace {{AttributesNamespace}}; - /// - /// Specifies an endpoint filter type to apply to the annotated endpoint or class. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{EndpointFilterAttributeName}} : global::System.Attribute - { - /// - /// Gets the CLR type of the endpoint filter. - /// - public global::System.Type FilterType { get; } + /// + /// Specifies an endpoint filter type to apply to the annotated endpoint or class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{EndpointFilterAttributeName}} : global::System.Attribute + { + /// + /// Gets the CLR type of the endpoint filter. + /// + public global::System.Type FilterType { get; } - /// - /// Initializes a new instance of the class. - /// - /// The CLR type of the endpoint filter. - public {{EndpointFilterAttributeName}}(global::System.Type filterType) - { - FilterType = filterType ?? throw new global::System.ArgumentNullException(nameof(filterType)); - } - } - - /// - /// Specifies an endpoint filter type using a generic argument. - /// - /// The CLR type of the endpoint filter. - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{EndpointFilterAttributeName}} : global::System.Attribute - { - /// - /// Gets the CLR type of the endpoint filter. - /// - public global::System.Type FilterType => typeof(TFilter); - } + /// + /// Initializes a new instance of the class. + /// + /// The CLR type of the endpoint filter. + public {{EndpointFilterAttributeName}}(global::System.Type filterType) + { + FilterType = filterType ?? throw new global::System.ArgumentNullException(nameof(filterType)); + } + } - """, Encoding.UTF8 + /// + /// Specifies an endpoint filter type using a generic argument. + /// + /// The CLR type of the endpoint filter. + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{EndpointFilterAttributeName}} : global::System.Attribute + { + /// + /// Gets the CLR type of the endpoint filter. + /// + public global::System.Type FilterType => typeof(TFilter); + } + + """, Encoding.UTF8 ); internal static readonly SourceText ProducesResponseAttributeSourceText = SourceText.From($$""" - {{FileHeader}} + {{FileHeader}} - namespace {{AttributesNamespace}}; + namespace {{AttributesNamespace}}; - /// - /// Specifies a response type, status code, and content types produced by the annotated endpoint or class. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute - { - /// - /// Gets the response type produced by the endpoint. - /// - public global::System.Type? ResponseType { get; init; } + /// + /// Specifies a response type, status code, and content types produced by the annotated endpoint or class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute + { + /// + /// Gets the response type produced by the endpoint. + /// + public global::System.Type? ResponseType { get; init; } - /// - /// Gets the HTTP status code returned by the endpoint. - /// - public int StatusCode { get; } + /// + /// Gets the HTTP status code returned by the endpoint. + /// + public int StatusCode { get; } - /// - /// Gets the primary content type produced by the endpoint. - /// - public string? ContentType { get; } + /// + /// Gets the primary content type produced by the endpoint. + /// + public string? ContentType { get; } - /// - /// Gets the additional content types produced by the endpoint. - /// - public string[] AdditionalContentTypes { get; } + /// + /// Gets the additional content types produced by the endpoint. + /// + public string[] AdditionalContentTypes { get; } - /// - /// Initializes a new instance of the class. - /// - /// The HTTP status code returned by the endpoint. - /// The primary content type produced by the endpoint. - /// Additional content types produced by the endpoint. - public {{ProducesResponseAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) - { - StatusCode = statusCode; - ContentType = contentType; - AdditionalContentTypes = additionalContentTypes ?? []; - } - } + /// + /// Initializes a new instance of the class. + /// + /// The HTTP status code returned by the endpoint. + /// The primary content type produced by the endpoint. + /// Additional content types produced by the endpoint. + public {{ProducesResponseAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) + { + StatusCode = statusCode; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes ?? []; + } + } + + /// + /// Specifies a response type using a generic argument along with status code and content types produced by the annotated endpoint or class. + /// + /// The CLR type of the response body. + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute + { + /// + /// Gets the response type produced by the endpoint. + /// + public global::System.Type ResponseType => typeof(TResponse); + + /// + /// Gets the HTTP status code returned by the endpoint. + /// + public int StatusCode { get; } + + /// + /// Gets the primary content type produced by the endpoint. + /// + public string? ContentType { get; } + + /// + /// Gets the additional content types produced by the endpoint. + /// + public string[] AdditionalContentTypes { get; } + + /// + /// Initializes a new instance of the generic Produces attribute class. + /// + /// The HTTP status code returned by the endpoint. + /// The primary content type produced by the endpoint. + /// Additional content types produced by the endpoint. + public {{ProducesResponseAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) + { + StatusCode = statusCode; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes ?? []; + } + } + + """, Encoding.UTF8 + ); + + internal static readonly SourceText ProducesProblemAttributeSourceText = SourceText.From($$""" + {{FileHeader}} + + namespace {{AttributesNamespace}}; /// - /// Specifies a response type using a generic argument along with status code and content types produced by the annotated endpoint or class. + /// Specifies that the endpoint produces a problem details payload. /// - /// The CLR type of the response body. [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute + internal sealed class {{ProducesProblemAttributeName}} : global::System.Attribute { - /// - /// Gets the response type produced by the endpoint. - /// - public global::System.Type ResponseType => typeof(TResponse); - /// /// Gets the HTTP status code returned by the endpoint. /// @@ -558,12 +601,12 @@ internal sealed class {{ProducesResponseAttributeName}} : global::Sys public string[] AdditionalContentTypes { get; } /// - /// Initializes a new instance of the generic Produces attribute class. + /// Initializes a new instance of the class. /// /// The HTTP status code returned by the endpoint. /// The primary content type produced by the endpoint. /// Additional content types produced by the endpoint. - public {{ProducesResponseAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) + public {{ProducesProblemAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status500InternalServerError, string? contentType = null, params string[] additionalContentTypes) { StatusCode = statusCode; ContentType = contentType; @@ -574,90 +617,47 @@ internal sealed class {{ProducesResponseAttributeName}} : global::Sys """, Encoding.UTF8 ); - internal static readonly SourceText ProducesProblemAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; - - /// - /// Specifies that the endpoint produces a problem details payload. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{ProducesProblemAttributeName}} : global::System.Attribute - { - /// - /// Gets the HTTP status code returned by the endpoint. - /// - public int StatusCode { get; } - - /// - /// Gets the primary content type produced by the endpoint. - /// - public string? ContentType { get; } - - /// - /// Gets the additional content types produced by the endpoint. - /// - public string[] AdditionalContentTypes { get; } - - /// - /// Initializes a new instance of the class. - /// - /// The HTTP status code returned by the endpoint. - /// The primary content type produced by the endpoint. - /// Additional content types produced by the endpoint. - public {{ProducesProblemAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status500InternalServerError, string? contentType = null, params string[] additionalContentTypes) - { - StatusCode = statusCode; - ContentType = contentType; - AdditionalContentTypes = additionalContentTypes ?? []; - } - } - - """, Encoding.UTF8 - ); - internal static readonly SourceText ProducesValidationProblemAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; - - /// - /// Specifies that the endpoint produces a validation problem details payload. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{ProducesValidationProblemAttributeName}} : global::System.Attribute - { - /// - /// Gets the HTTP status code returned by the endpoint. - /// - public int StatusCode { get; } - - /// - /// Gets the primary content type produced by the endpoint. - /// - public string? ContentType { get; } - - /// - /// Gets the additional content types produced by the endpoint. - /// - public string[] AdditionalContentTypes { get; } - - /// - /// Initializes a new instance of the class. - /// - /// The HTTP status code returned by the endpoint. - /// The primary content type produced by the endpoint. - /// Additional content types produced by the endpoint. - public {{ProducesValidationProblemAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status400BadRequest, string? contentType = null, params string[] additionalContentTypes) - { - StatusCode = statusCode; - ContentType = contentType; - AdditionalContentTypes = additionalContentTypes ?? []; - } - } - - """, Encoding.UTF8 + {{FileHeader}} + + namespace {{AttributesNamespace}}; + + /// + /// Specifies that the endpoint produces a validation problem details payload. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{ProducesValidationProblemAttributeName}} : global::System.Attribute + { + /// + /// Gets the HTTP status code returned by the endpoint. + /// + public int StatusCode { get; } + + /// + /// Gets the primary content type produced by the endpoint. + /// + public string? ContentType { get; } + + /// + /// Gets the additional content types produced by the endpoint. + /// + public string[] AdditionalContentTypes { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The HTTP status code returned by the endpoint. + /// The primary content type produced by the endpoint. + /// Additional content types produced by the endpoint. + public {{ProducesValidationProblemAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status400BadRequest, string? contentType = null, params string[] additionalContentTypes) + { + StatusCode = statusCode; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes ?? []; + } + } + + """, Encoding.UTF8 ); private static HttpAttributeDefinition CreateHttpAttributeDefinition(string attributeName, string verb, bool allowOptionalPattern = false) diff --git a/src/GeneratedEndpoints/Common/Constants.cs b/src/GeneratedEndpoints/Common/Constants.cs index 05235a2..8ef77d1 100644 --- a/src/GeneratedEndpoints/Common/Constants.cs +++ b/src/GeneratedEndpoints/Common/Constants.cs @@ -4,9 +4,6 @@ namespace GeneratedEndpoints.Common; internal static partial class Constants { - private const string BaseNamespace = "Microsoft.AspNetCore.Generated"; - private const string AttributesNamespace = $"{BaseNamespace}.Attributes"; - internal const string FallbackHttpMethod = "__FALLBACK__"; internal const string NameAttributeNamedParameter = "Name"; @@ -15,51 +12,39 @@ internal static partial class Constants internal const string IsOptionalAttributeNamedParameter = "IsOptional"; internal const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute"; - private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}"; internal const string RequireAuthorizationAttributeHint = $"{RequireAuthorizationAttributeFullyQualifiedName}.gs.cs"; internal const string RequireCorsAttributeName = "RequireCorsAttribute"; - private const string RequireCorsAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireCorsAttributeName}"; internal const string RequireCorsAttributeHint = $"{RequireCorsAttributeFullyQualifiedName}.gs.cs"; internal const string RequireRateLimitingAttributeName = "RequireRateLimitingAttribute"; - private const string RequireRateLimitingAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireRateLimitingAttributeName}"; internal const string RequireRateLimitingAttributeHint = $"{RequireRateLimitingAttributeFullyQualifiedName}.gs.cs"; internal const string RequireHostAttributeName = "RequireHostAttribute"; - private const string RequireHostAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireHostAttributeName}"; internal const string RequireHostAttributeHint = $"{RequireHostAttributeFullyQualifiedName}.gs.cs"; internal const string DisableAntiforgeryAttributeName = "DisableAntiforgeryAttribute"; - private const string DisableAntiforgeryAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableAntiforgeryAttributeName}"; internal const string DisableAntiforgeryAttributeHint = $"{DisableAntiforgeryAttributeFullyQualifiedName}.gs.cs"; internal const string ShortCircuitAttributeName = "ShortCircuitAttribute"; - private const string ShortCircuitAttributeFullyQualifiedName = $"{AttributesNamespace}.{ShortCircuitAttributeName}"; internal const string ShortCircuitAttributeHint = $"{ShortCircuitAttributeFullyQualifiedName}.gs.cs"; internal const string DisableRequestTimeoutAttributeName = "DisableRequestTimeoutAttribute"; - private const string DisableRequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableRequestTimeoutAttributeName}"; internal const string DisableRequestTimeoutAttributeHint = $"{DisableRequestTimeoutAttributeFullyQualifiedName}.gs.cs"; internal const string DisableValidationAttributeName = "DisableValidationAttribute"; - private const string DisableValidationAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableValidationAttributeName}"; internal const string DisableValidationAttributeHint = $"{DisableValidationAttributeFullyQualifiedName}.gs.cs"; internal const string RequestTimeoutAttributeName = "RequestTimeoutAttribute"; - private const string RequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequestTimeoutAttributeName}"; internal const string RequestTimeoutAttributeHint = $"{RequestTimeoutAttributeFullyQualifiedName}.gs.cs"; internal const string OrderAttributeName = "OrderAttribute"; - private const string OrderAttributeFullyQualifiedName = $"{AttributesNamespace}.{OrderAttributeName}"; internal const string OrderAttributeHint = $"{OrderAttributeFullyQualifiedName}.gs.cs"; internal const string MapGroupAttributeName = "MapGroupAttribute"; - private const string MapGroupAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapGroupAttributeName}"; internal const string MapGroupAttributeHint = $"{MapGroupAttributeFullyQualifiedName}.gs.cs"; internal const string SummaryAttributeName = "SummaryAttribute"; - private const string SummaryAttributeFullyQualifiedName = $"{AttributesNamespace}.{SummaryAttributeName}"; internal const string SummaryAttributeHint = $"{SummaryAttributeFullyQualifiedName}.gs.cs"; internal const string DisplayNameAttributeName = nameof(DisplayNameAttribute); @@ -69,35 +54,28 @@ internal static partial class Constants internal const string ExcludeFromDescriptionAttributeName = "ExcludeFromDescriptionAttribute"; internal const string EndpointFilterAttributeName = "EndpointFilterAttribute"; - private const string EndpointFilterAttributeFullyQualifiedName = $"{AttributesNamespace}.{EndpointFilterAttributeName}"; internal const string EndpointFilterAttributeHint = $"{EndpointFilterAttributeFullyQualifiedName}.gs.cs"; internal const string AcceptsAttributeName = "AcceptsAttribute"; - private const string AcceptsAttributeFullyQualifiedName = $"{AttributesNamespace}.{AcceptsAttributeName}"; internal const string AcceptsAttributeHint = $"{AcceptsAttributeFullyQualifiedName}.gs.cs"; internal const string ProducesResponseAttributeName = "ProducesResponseAttribute"; - private const string ProducesResponseAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesResponseAttributeName}"; internal const string ProducesResponseAttributeHint = $"{ProducesResponseAttributeFullyQualifiedName}.gs.cs"; internal const string ProducesProblemAttributeName = "ProducesProblemAttribute"; - private const string ProducesProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesProblemAttributeName}"; internal const string ProducesProblemAttributeHint = $"{ProducesProblemAttributeFullyQualifiedName}.gs.cs"; internal const string ProducesValidationProblemAttributeName = "ProducesValidationProblemAttribute"; - private const string ProducesValidationProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesValidationProblemAttributeName}"; internal const string ProducesValidationProblemAttributeHint = $"{ProducesValidationProblemAttributeFullyQualifiedName}.gs.cs"; internal const string RoutingNamespace = $"{BaseNamespace}.Routing"; internal const string AddEndpointHandlersClassName = "EndpointServicesExtensions"; internal const string AddEndpointHandlersMethodName = "AddEndpointHandlers"; - private const string AddEndpointHandlersMethodFullyQualifiedName = $"{RoutingNamespace}.{AddEndpointHandlersMethodName}"; internal const string AddEndpointHandlersMethodHint = $"{AddEndpointHandlersMethodFullyQualifiedName}.g.cs"; internal const string UseEndpointHandlersClassName = "EndpointRouteBuilderExtensions"; internal const string UseEndpointHandlersMethodName = "MapEndpointHandlers"; - private const string UseEndpointHandlersMethodFullyQualifiedName = $"{RoutingNamespace}.{UseEndpointHandlersMethodName}"; internal const string UseEndpointHandlersMethodHint = $"{UseEndpointHandlersMethodFullyQualifiedName}.g.cs"; internal const string ConfigureMethodName = "Configure"; @@ -113,4 +91,25 @@ internal static partial class Constants internal static readonly string[] AspNetCoreRoutingNamespaceParts = ["Microsoft", "AspNetCore", "Routing"]; internal static readonly string[] ExtensionsDependencyInjectionNamespaceParts = ["Microsoft", "Extensions", "DependencyInjection"]; internal static readonly string[] ComponentModelNamespaceParts = ["System", "ComponentModel"]; + private const string BaseNamespace = "Microsoft.AspNetCore.Generated"; + private const string AttributesNamespace = $"{BaseNamespace}.Attributes"; + private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}"; + private const string RequireCorsAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireCorsAttributeName}"; + private const string RequireRateLimitingAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireRateLimitingAttributeName}"; + private const string RequireHostAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireHostAttributeName}"; + private const string DisableAntiforgeryAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableAntiforgeryAttributeName}"; + private const string ShortCircuitAttributeFullyQualifiedName = $"{AttributesNamespace}.{ShortCircuitAttributeName}"; + private const string DisableRequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableRequestTimeoutAttributeName}"; + private const string DisableValidationAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableValidationAttributeName}"; + private const string RequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequestTimeoutAttributeName}"; + private const string OrderAttributeFullyQualifiedName = $"{AttributesNamespace}.{OrderAttributeName}"; + private const string MapGroupAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapGroupAttributeName}"; + private const string SummaryAttributeFullyQualifiedName = $"{AttributesNamespace}.{SummaryAttributeName}"; + private const string EndpointFilterAttributeFullyQualifiedName = $"{AttributesNamespace}.{EndpointFilterAttributeName}"; + private const string AcceptsAttributeFullyQualifiedName = $"{AttributesNamespace}.{AcceptsAttributeName}"; + private const string ProducesResponseAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesResponseAttributeName}"; + private const string ProducesProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesProblemAttributeName}"; + private const string ProducesValidationProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesValidationProblemAttributeName}"; + private const string AddEndpointHandlersMethodFullyQualifiedName = $"{RoutingNamespace}.{AddEndpointHandlersMethodName}"; + private const string UseEndpointHandlersMethodFullyQualifiedName = $"{RoutingNamespace}.{UseEndpointHandlersMethodName}"; } diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index f949d1e..f986ac9 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -10,7 +10,6 @@ internal static class EndpointConfigurationFactory public static EndpointConfiguration Create(ISymbol symbol) { - var attributes = symbol.GetAttributes(); string? displayName = null; @@ -218,14 +217,15 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym { string? requestType; if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) - requestType = attributeClass.TypeArguments[0].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + requestType = attributeClass.TypeArguments[0] + .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); else if (attribute.GetNamedTypeSymbol(RequestTypeAttributeNamedParameter) is { } requestTypeSymbol) requestType = requestTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); else return; var contentType = attribute.GetConstructorStringValue() ?? ApplicationJsonContentType; - var additionalContentTypes = attribute.GetConstructorStringArray(position: 1); + var additionalContentTypes = attribute.GetConstructorStringArray(1); var isOptional = attribute.GetNamedBoolValue(IsOptionalAttributeNamedParameter); var acceptMetadata = new AcceptsMetadata(requestType, contentType, additionalContentTypes, isOptional); @@ -238,15 +238,16 @@ private static void TryAddProducesMetadata(AttributeData attribute, INamedTypeSy { string? responseType; if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) - responseType = attributeClass.TypeArguments[0].ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + responseType = attributeClass.TypeArguments[0] + .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); else if (attribute.GetNamedTypeSymbol(ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol) responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); else return; var statusCode = attribute.GetConstructorIntValue() ?? 200; - var contentType = attribute.GetConstructorStringValue(position: 1); - var additionalContentTypes = attribute.GetConstructorStringArray(position: 2); + var contentType = attribute.GetConstructorStringValue(1); + var additionalContentTypes = attribute.GetConstructorStringArray(2); var producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes); diff --git a/src/GeneratedEndpoints/Common/HttpAttributeDefinition.cs b/src/GeneratedEndpoints/Common/HttpAttributeDefinition.cs index e740ab6..6c757de 100644 --- a/src/GeneratedEndpoints/Common/HttpAttributeDefinition.cs +++ b/src/GeneratedEndpoints/Common/HttpAttributeDefinition.cs @@ -2,10 +2,4 @@ namespace GeneratedEndpoints.Common; -internal readonly record struct HttpAttributeDefinition( - string Name, - string FullyQualifiedName, - string Hint, - string Verb, - SourceText SourceText -); +internal readonly record struct HttpAttributeDefinition(string Name, string FullyQualifiedName, string Hint, string Verb, SourceText SourceText); diff --git a/src/GeneratedEndpoints/Common/IncrementalValueProviderExtensions.cs b/src/GeneratedEndpoints/Common/IncrementalValueProviderExtensions.cs index f2b97e7..46f3976 100644 --- a/src/GeneratedEndpoints/Common/IncrementalValueProviderExtensions.cs +++ b/src/GeneratedEndpoints/Common/IncrementalValueProviderExtensions.cs @@ -42,13 +42,14 @@ IncrementalValueProvider provider4 .Select((tuple, _) => (tuple.Left.Left.Left, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right)); } - public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5)> Combine( - this IncrementalValueProvider provider1, - IncrementalValueProvider provider2, - IncrementalValueProvider provider3, - IncrementalValueProvider provider4, - IncrementalValueProvider provider5 - ) + public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5)> + Combine( + this IncrementalValueProvider provider1, + IncrementalValueProvider provider2, + IncrementalValueProvider provider3, + IncrementalValueProvider provider4, + IncrementalValueProvider provider5 + ) { return provider1.Combine(provider2) .Combine(provider3) @@ -57,32 +58,36 @@ IncrementalValueProvider provider5 .Select((tuple, _) => (tuple.Left.Left.Left.Left, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right)); } - public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6)> Combine( - this IncrementalValueProvider provider1, - IncrementalValueProvider provider2, - IncrementalValueProvider provider3, - IncrementalValueProvider provider4, - IncrementalValueProvider provider5, - IncrementalValueProvider provider6 - ) + public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6)> + Combine( + this IncrementalValueProvider provider1, + IncrementalValueProvider provider2, + IncrementalValueProvider provider3, + IncrementalValueProvider provider4, + IncrementalValueProvider provider5, + IncrementalValueProvider provider6 + ) { return provider1.Combine(provider2) .Combine(provider3) .Combine(provider4) .Combine(provider5) .Combine(provider6) - .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right)); + .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, + tuple.Left.Right, tuple.Right) + ); } - public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7)> Combine( - this IncrementalValueProvider provider1, - IncrementalValueProvider provider2, - IncrementalValueProvider provider3, - IncrementalValueProvider provider4, - IncrementalValueProvider provider5, - IncrementalValueProvider provider6, - IncrementalValueProvider provider7 - ) + public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7)> + Combine( + this IncrementalValueProvider provider1, + IncrementalValueProvider provider2, + IncrementalValueProvider provider3, + IncrementalValueProvider provider4, + IncrementalValueProvider provider5, + IncrementalValueProvider provider6, + IncrementalValueProvider provider7 + ) { return provider1.Combine(provider2) .Combine(provider3) @@ -90,19 +95,22 @@ IncrementalValueProvider provider7 .Combine(provider5) .Combine(provider6) .Combine(provider7) - .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right)); + .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Right, + tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) + ); } - public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7, TItem8 Item8)> Combine( - this IncrementalValueProvider provider1, - IncrementalValueProvider provider2, - IncrementalValueProvider provider3, - IncrementalValueProvider provider4, - IncrementalValueProvider provider5, - IncrementalValueProvider provider6, - IncrementalValueProvider provider7, - IncrementalValueProvider provider8 - ) + public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7, TItem8 Item8)> + Combine( + this IncrementalValueProvider provider1, + IncrementalValueProvider provider2, + IncrementalValueProvider provider3, + IncrementalValueProvider provider4, + IncrementalValueProvider provider5, + IncrementalValueProvider provider6, + IncrementalValueProvider provider7, + IncrementalValueProvider provider8 + ) { return provider1.Combine(provider2) .Combine(provider3) @@ -111,12 +119,13 @@ IncrementalValueProvider provider8 .Combine(provider6) .Combine(provider7) .Combine(provider8) - .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, - tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) + .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Right, + tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) ); } - public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7, TItem8 Item8, TItem9 Item9)> + public static + IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7, TItem8 Item8, TItem9 Item9)> Combine( this IncrementalValueProvider provider1, IncrementalValueProvider provider2, @@ -137,13 +146,15 @@ IncrementalValueProvider provider9 .Combine(provider7) .Combine(provider8) .Combine(provider9) - .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Right, - tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) + .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Left.Left.Right, + tuple.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, + tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) ); } - public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7, TItem8 Item8, TItem9 Item9, TItem10 Item10)> - Combine( + public static + IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7, TItem8 Item8, TItem9 Item9, + TItem10 Item10)> Combine( this IncrementalValueProvider provider1, IncrementalValueProvider provider2, IncrementalValueProvider provider3, @@ -165,8 +176,9 @@ IncrementalValueProvider provider10 .Combine(provider8) .Combine(provider9) .Combine(provider10) - .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Left.Right, - tuple.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) + .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Left.Left.Left.Right, + tuple.Left.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Right, + tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) ); } } diff --git a/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs b/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs index dca5b1c..bb4cdfc 100644 --- a/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs +++ b/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs @@ -82,8 +82,7 @@ private static BindingSource GetBindingSourceFromAttributeClass(INamedTypeSymbol "FromBodyAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromBody, "FromFormAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromForm, "FromServicesAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromServices, - "FromKeyedServicesAttribute" when IsInNamespace(namespaceSymbol, ExtensionsDependencyInjectionNamespaceParts) - => BindingSource.FromKeyedServices, + "FromKeyedServicesAttribute" when IsInNamespace(namespaceSymbol, ExtensionsDependencyInjectionNamespaceParts) => BindingSource.FromKeyedServices, "AsParametersAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreHttpNamespaceParts) => BindingSource.AsParameters, _ => BindingSource.None, }; diff --git a/src/GeneratedEndpoints/Common/ProducesProblemMetadata.cs b/src/GeneratedEndpoints/Common/ProducesProblemMetadata.cs index 04baeaa..443f70b 100644 --- a/src/GeneratedEndpoints/Common/ProducesProblemMetadata.cs +++ b/src/GeneratedEndpoints/Common/ProducesProblemMetadata.cs @@ -1,7 +1,3 @@ namespace GeneratedEndpoints.Common; -internal readonly record struct ProducesProblemMetadata( - int StatusCode, - string? ContentType, - EquatableImmutableArray? AdditionalContentTypes -); +internal readonly record struct ProducesProblemMetadata(int StatusCode, string? ContentType, EquatableImmutableArray? AdditionalContentTypes); diff --git a/src/GeneratedEndpoints/Common/ProducesValidationProblemMetadata.cs b/src/GeneratedEndpoints/Common/ProducesValidationProblemMetadata.cs index 46bca0c..2234852 100644 --- a/src/GeneratedEndpoints/Common/ProducesValidationProblemMetadata.cs +++ b/src/GeneratedEndpoints/Common/ProducesValidationProblemMetadata.cs @@ -1,7 +1,3 @@ namespace GeneratedEndpoints.Common; -internal readonly record struct ProducesValidationProblemMetadata( - int StatusCode, - string? ContentType, - EquatableImmutableArray? AdditionalContentTypes -); +internal readonly record struct ProducesValidationProblemMetadata(int StatusCode, string? ContentType, EquatableImmutableArray? AdditionalContentTypes); diff --git a/src/GeneratedEndpoints/Common/RequestHandler.cs b/src/GeneratedEndpoints/Common/RequestHandler.cs index 6017b77..3ca8301 100644 --- a/src/GeneratedEndpoints/Common/RequestHandler.cs +++ b/src/GeneratedEndpoints/Common/RequestHandler.cs @@ -4,7 +4,6 @@ namespace GeneratedEndpoints.Common; internal record struct RequestHandler { - private string? _name; public required RequestHandlerClass Class { get; init; } public required RequestHandlerMethod Method { get; init; } public required string HttpMethod { get; init; } @@ -16,6 +15,8 @@ public required string? Name init => _name = value; } + private string? _name; + public void SetFullyQualifiedName() { ReadOnlySpan className = Class.Name; diff --git a/src/GeneratedEndpoints/Common/RequestHandlerAttributeKind.cs b/src/GeneratedEndpoints/Common/RequestHandlerAttributeKind.cs index 1c4ff22..1a80118 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerAttributeKind.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerAttributeKind.cs @@ -24,5 +24,5 @@ internal enum RequestHandlerAttributeKind Description, AllowAnonymous, Tags, - ExcludeFromDescription + ExcludeFromDescription, } diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs index 645073b..3226132 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs @@ -24,8 +24,7 @@ public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, Cancellatio var name = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); var isStatic = classSymbol.IsStatic; - var configureMethodDetails = GetConfigureMethodDetails(classSymbol, cancellationToken - ); + var configureMethodDetails = GetConfigureMethodDetails(classSymbol, cancellationToken); var classConfiguration = EndpointConfigurationFactory.Create(classSymbol); @@ -37,10 +36,7 @@ public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, Cancellatio } } - private static ConfigureMethodDetails GetConfigureMethodDetails( - INamedTypeSymbol classSymbol, - CancellationToken cancellationToken - ) + private static ConfigureMethodDetails GetConfigureMethodDetails(INamedTypeSymbol classSymbol, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -67,10 +63,7 @@ CancellationToken cancellationToken return new ConfigureMethodDetails(hasConfigureMethod, acceptsServiceProvider); } - private static bool IsConfigureMethod( - IMethodSymbol methodSymbol, - out bool acceptsServiceProvider - ) + private static bool IsConfigureMethod(IMethodSymbol methodSymbol, out bool acceptsServiceProvider) { acceptsServiceProvider = false; @@ -112,10 +105,7 @@ private static bool IsServiceProviderParameter(ITypeSymbol typeSymbol) return MatchesServiceProvider(typeSymbol); } - private static bool HasEndpointConventionBuilderConstraint( - ITypeParameterSymbol builderTypeParameter, - IMethodSymbol methodSymbol - ) + private static bool HasEndpointConventionBuilderConstraint(ITypeParameterSymbol builderTypeParameter, IMethodSymbol methodSymbol) { var symbolMatches = builderTypeParameter.ConstraintTypes.Any(MatchesEndpointConventionBuilder); if (symbolMatches) diff --git a/src/GeneratedEndpoints/GeneratedEndpoints.csproj b/src/GeneratedEndpoints/GeneratedEndpoints.csproj index 55bf32f..ca27350 100644 --- a/src/GeneratedEndpoints/GeneratedEndpoints.csproj +++ b/src/GeneratedEndpoints/GeneratedEndpoints.csproj @@ -20,12 +20,12 @@ - + all runtime; build; native; contentfiles; analyzers; buildtransitive - + all @@ -36,9 +36,9 @@ GeneratedEndpoints - 10.0.0 - 10.0.0.0 - 10.0.0.0 + 10.0.1 + 10.0.1.0 + 10.0.1.0 en-US false diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 3cd6dd3..00181d4 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -1,4 +1,3 @@ -using System.Collections.Generic; using System.Collections.Immutable; using System.Runtime.CompilerServices; using GeneratedEndpoints.Common; diff --git a/tests/GeneratedEndpoints.Tests.Lab/GeneratedEndpoints.Tests.Lab.csproj b/tests/GeneratedEndpoints.Tests.Lab/GeneratedEndpoints.Tests.Lab.csproj index 7ab05ac..db2be7e 100644 --- a/tests/GeneratedEndpoints.Tests.Lab/GeneratedEndpoints.Tests.Lab.csproj +++ b/tests/GeneratedEndpoints.Tests.Lab/GeneratedEndpoints.Tests.Lab.csproj @@ -12,11 +12,11 @@ - + - - appsettings.json - + + appsettings.json + diff --git a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.cs b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.cs index 1831f55..9b27e5f 100644 --- a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.cs +++ b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.cs @@ -8,8 +8,8 @@ namespace GeneratedEndpoints.Tests; public class AttributeGenerationTests { private const string AttributeTestSource = "internal static class AttributeTestEndpoints { }"; - private static readonly GeneratorDriverRunResult GeneratorResult = - TestHelpers.RunGenerator(TestHelpers.GetSources(AttributeTestSource, withNamespace: true)); + + private static readonly GeneratorDriverRunResult GeneratorResult = TestHelpers.RunGenerator(TestHelpers.GetSources(AttributeTestSource, true)); public AttributeGenerationTests() { @@ -18,116 +18,174 @@ public AttributeGenerationTests() [Fact] public Task MapGetAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapGetAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapGetAttribute.gs.cs"); + } [Fact] public Task MapPostAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapPostAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapPostAttribute.gs.cs"); + } [Fact] public Task MapPutAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapPutAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapPutAttribute.gs.cs"); + } [Fact] public Task MapPatchAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapPatchAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapPatchAttribute.gs.cs"); + } [Fact] public Task MapDeleteAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapDeleteAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapDeleteAttribute.gs.cs"); + } [Fact] public Task MapOptionsAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapOptionsAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapOptionsAttribute.gs.cs"); + } [Fact] public Task MapHeadAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapHeadAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapHeadAttribute.gs.cs"); + } [Fact] public Task MapQueryAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapQueryAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapQueryAttribute.gs.cs"); + } [Fact] public Task MapTraceAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapTraceAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapTraceAttribute.gs.cs"); + } [Fact] public Task MapConnectAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapConnectAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapConnectAttribute.gs.cs"); + } [Fact] public Task MapFallbackAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapFallbackAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapFallbackAttribute.gs.cs"); + } [Fact] public Task RequireAuthorizationAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireAuthorizationAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireAuthorizationAttribute.gs.cs"); + } [Fact] public Task RequireCorsAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireCorsAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireCorsAttribute.gs.cs"); + } [Fact] public Task RequireRateLimitingAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireRateLimitingAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireRateLimitingAttribute.gs.cs"); + } [Fact] public Task RequireHostAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireHostAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireHostAttribute.gs.cs"); + } [Fact] public Task DisableAntiforgeryAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.DisableAntiforgeryAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.DisableAntiforgeryAttribute.gs.cs"); + } [Fact] public Task ShortCircuitAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ShortCircuitAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ShortCircuitAttribute.gs.cs"); + } [Fact] public Task DisableRequestTimeoutAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.DisableRequestTimeoutAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.DisableRequestTimeoutAttribute.gs.cs"); + } [Fact] public Task DisableValidationAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.DisableValidationAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.DisableValidationAttribute.gs.cs"); + } [Fact] public Task RequestTimeoutAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequestTimeoutAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequestTimeoutAttribute.gs.cs"); + } [Fact] public Task OrderAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.OrderAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.OrderAttribute.gs.cs"); + } [Fact] public Task MapGroupAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapGroupAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapGroupAttribute.gs.cs"); + } [Fact] public Task SummaryAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.SummaryAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.SummaryAttribute.gs.cs"); + } [Fact] public Task AcceptsAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.AcceptsAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.AcceptsAttribute.gs.cs"); + } [Fact] public Task EndpointFilterAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.EndpointFilterAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.EndpointFilterAttribute.gs.cs"); + } [Fact] public Task ProducesResponseAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ProducesResponseAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ProducesResponseAttribute.gs.cs"); + } [Fact] public Task ProducesProblemAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ProducesProblemAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ProducesProblemAttribute.gs.cs"); + } [Fact] public Task ProducesValidationProblemAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ProducesValidationProblemAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ProducesValidationProblemAttribute.gs.cs"); + } private static Task VerifyAttributeAsync(string fileName) - => GeneratorResult.VerifyAsync(fileName); + { + return GeneratorResult.VerifyAsync(fileName); + } } diff --git a/tests/GeneratedEndpoints.Tests/Common/ScenarioNamer.cs b/tests/GeneratedEndpoints.Tests/Common/ScenarioNamer.cs index db7d771..27900d2 100644 --- a/tests/GeneratedEndpoints.Tests/Common/ScenarioNamer.cs +++ b/tests/GeneratedEndpoints.Tests/Common/ScenarioNamer.cs @@ -26,15 +26,13 @@ public static string Create(string prefix, params (string Name, object? Value)[] private static string Sanitize(object? value) { if (value is null) - { return "None"; - } return value switch { bool b => b ? "On" : "Off", string s => s, - _ => value.ToString() ?? "Value" + _ => value.ToString() ?? "Value", }; } } diff --git a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs index 2b843d6..4fef5bd 100644 --- a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs +++ b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs @@ -53,29 +53,22 @@ public static string BuildAuthorizationMatrixSource( bool excludeFromDescription, string? mapGroupPattern = null, bool classDisableValidation = false, - bool methodDisableValidation = false) + bool methodDisableValidation = false + ) { var builder = new StringBuilder(); if (classAllowAnonymous) - { builder.AppendLine("[AllowAnonymous]"); - } if (classRequireAuthorization) - { builder.AppendLine("[RequireAuthorization(\"ClassPolicy\")]"); - } if (classTags) - { builder.AppendLine("[Tags(\"Class\", \"Matrix\")]"); - } if (!string.IsNullOrWhiteSpace(classHost)) - { builder.AppendLine($"[RequireHost(\"{classHost}\")]"); - } if (classRequireCors) { @@ -84,79 +77,55 @@ public static string BuildAuthorizationMatrixSource( } if (!string.IsNullOrWhiteSpace(groupName) && mapGroupPattern is null) - { mapGroupPattern = ""; - } if (mapGroupPattern is not null) { var mapGroupAttribute = new StringBuilder(); mapGroupAttribute.Append($"[MapGroup(\"{mapGroupPattern}\""); if (!string.IsNullOrWhiteSpace(groupName)) - { mapGroupAttribute.Append($", Name = \"{groupName}\""); - } mapGroupAttribute.Append(")]"); builder.AppendLine(mapGroupAttribute.ToString()); } if (applyShortCircuit) - { builder.AppendLine("[ShortCircuit]"); - } if (classDisableValidation) - { builder.AppendLine("[DisableValidation]"); - } if (applyRequestTimeout) { - var timeoutArgument = string.IsNullOrWhiteSpace(requestTimeoutPolicy) - ? "" - : $"(\"{requestTimeoutPolicy}\")"; + var timeoutArgument = string.IsNullOrWhiteSpace(requestTimeoutPolicy) ? "" : $"(\"{requestTimeoutPolicy}\")"; builder.AppendLine($"[RequestTimeout{timeoutArgument}]"); } if (disableRequestTimeout) - { builder.AppendLine("[DisableRequestTimeout]"); - } if (orderValue != 0) - { builder.AppendLine($"[Order({orderValue})]"); - } if (excludeFromDescription) - { builder.AppendLine("[ExcludeFromDescription]"); - } builder.AppendLine("internal sealed class AuthorizationMatrixEndpoints"); builder.AppendLine("{"); builder.AppendLine(" [MapGet(\"/matrix/{id:int}\", Name = \"GetMatrix\")]"); if (methodAllowAnonymous) - { builder.AppendLine(" [AllowAnonymous]"); - } if (methodRequireAuthorization) - { builder.AppendLine(" [RequireAuthorization(\"MethodPolicy\")]"); - } if (methodTags) - { builder.AppendLine(" [Tags(\"Method\", \"Matrix\")]"); - } if (!string.IsNullOrWhiteSpace(methodHost)) - { builder.AppendLine($" [RequireHost(\"{methodHost}\", \"contoso.com\")]"); - } if (methodRequireCors) { @@ -171,9 +140,7 @@ public static string BuildAuthorizationMatrixSource( } if (methodDisableValidation) - { builder.AppendLine(" [DisableValidation]"); - } builder.AppendLine(" public static Ok Handle(int id) => id >= 0 ? TypedResults.Ok() : TypedResults.Ok();"); @@ -196,52 +163,44 @@ public static string BuildConfigureAndFiltersSource( bool includeMethodLevelFilter, bool includeGenericFilter, bool configureRegistersFilter, - string metadataValue) + string metadataValue + ) { var builder = new StringBuilder(); builder.AppendLine("using Microsoft.AspNetCore.Builder;"); builder.AppendLine(); if (includeClassLevelFilter) - { builder.AppendLine("[EndpointFilter(typeof(TimingFilter))]"); - } builder.AppendLine("internal static class ConfigureFilterEndpoints"); builder.AppendLine("{"); builder.AppendLine(" [MapGet(\"/configure-filters\")]"); if (includeMethodLevelFilter) - { builder.AppendLine(" [EndpointFilter(typeof(ValidationFilter))]"); - } if (includeGenericFilter) - { builder.AppendLine(" [EndpointFilter]"); - } builder.AppendLine(" public static Ok Handle() => TypedResults.Ok();"); builder.AppendLine(); - builder.AppendLine(" public static void Configure(TBuilder builder" + (configureWithServiceProvider ? ", IServiceProvider services" : "") + ")"); + builder.AppendLine(" public static void Configure(TBuilder builder" + + (configureWithServiceProvider ? ", IServiceProvider services" : "") + + ")" + ); builder.AppendLine(" where TBuilder : IEndpointConventionBuilder"); builder.AppendLine(" {"); builder.AppendLine(" _ = builder;"); if (configureWithServiceProvider) - { builder.AppendLine(" _ = services;"); - } if (configureAddsMetadata) - { builder.AppendLine($" builder.WithMetadata(\"{metadataValue}\");"); - } if (configureRegistersFilter) - { builder.AppendLine(" builder.AddEndpointFilterFactory((context, next) => next);"); - } builder.AppendLine(" }"); builder.AppendLine("}"); @@ -270,7 +229,8 @@ public static string BuildHttpMethodMatrixSource( bool includeQuery, bool includeTrace, bool includeConnect, - bool includeMethodNameCollision) + bool includeMethodNameCollision + ) { var builder = new StringBuilder(); builder.AppendLine("using Microsoft.AspNetCore.Mvc;"); @@ -279,54 +239,36 @@ public static string BuildHttpMethodMatrixSource( builder.AppendLine("{"); if (includeGet) - { builder.AppendLine(" [MapGet(\"/matrix\")] public static Ok Get() => TypedResults.Ok();"); - } if (includePost) - { builder.AppendLine(" [MapPost(\"/matrix\")] public static Created Post() => TypedResults.Created(\"/matrix/1\", \"Created\");"); - } if (includePut) - { - builder.AppendLine(" [MapPut(\"/matrix/{id:int}\")] public static Results Put(int id) => id > 0 ? TypedResults.NoContent() : TypedResults.NotFound();"); - } + builder.AppendLine( + " [MapPut(\"/matrix/{id:int}\")] public static Results Put(int id) => id > 0 ? TypedResults.NoContent() : TypedResults.NotFound();" + ); if (includeDelete) - { builder.AppendLine(" [MapDelete(\"/matrix/{id:int}\")] public static IResult Delete(int id) => TypedResults.Ok();"); - } if (includeOptions) - { builder.AppendLine(" [MapOptions(\"/matrix\")] public static IResult Options() => TypedResults.Ok();"); - } if (includeHead) - { builder.AppendLine(" [MapHead(\"/matrix\")] public static IResult Head() => TypedResults.Ok();"); - } if (includePatch) - { builder.AppendLine(" [MapPatch(\"/matrix/{id:int}\")] public static IResult Patch(int id) => TypedResults.Ok();"); - } if (includeQuery) - { builder.AppendLine(" [MapQuery(\"/matrix/query\")] public static IResult Query([FromQuery] string value) => TypedResults.Ok(value);"); - } if (includeTrace) - { builder.AppendLine(" [MapTrace(\"/matrix\")] public static IResult Trace() => TypedResults.Ok();"); - } if (includeConnect) - { builder.AppendLine(" [MapConnect(\"/matrix\")] public static IResult Connect() => TypedResults.Ok();"); - } builder.AppendLine("}"); @@ -344,31 +286,33 @@ public static string BuildHttpMethodMatrixSource( } public static string BuildEndpointNameCollisionSource() - => """ - internal static class AlphaEndpoints - { - [MapGet("/alpha/collision")] public static Ok Collision() => TypedResults.Ok("alpha-collision"); - [MapGet("/alpha/unique")] public static Ok UniqueAlpha() => TypedResults.Ok("unique-alpha"); - } - - internal static class BetaEndpoints - { - [MapGet("/beta/unique")] public static Ok UniqueBeta() => TypedResults.Ok("unique-beta"); - [MapGet("/beta/collision")] public static Ok Collision() => TypedResults.Ok("beta-collision"); - } - - internal static class GammaEndpoints - { - [MapGet("/gamma/collision")] public static Ok Collision() => TypedResults.Ok("gamma-collision"); - [MapGet("/gamma/unique")] public static Ok UniqueGamma() => TypedResults.Ok("unique-gamma"); - } - - internal static class DeltaEndpoints - { - [MapGet("/delta/unique")] public static Ok UniqueDelta() => TypedResults.Ok("unique-delta"); - [MapGet("/delta/collision")] public static Ok Collision() => TypedResults.Ok("delta-collision"); - } - """; + { + return """ + internal static class AlphaEndpoints + { + [MapGet("/alpha/collision")] public static Ok Collision() => TypedResults.Ok("alpha-collision"); + [MapGet("/alpha/unique")] public static Ok UniqueAlpha() => TypedResults.Ok("unique-alpha"); + } + + internal static class BetaEndpoints + { + [MapGet("/beta/unique")] public static Ok UniqueBeta() => TypedResults.Ok("unique-beta"); + [MapGet("/beta/collision")] public static Ok Collision() => TypedResults.Ok("beta-collision"); + } + + internal static class GammaEndpoints + { + [MapGet("/gamma/collision")] public static Ok Collision() => TypedResults.Ok("gamma-collision"); + [MapGet("/gamma/unique")] public static Ok UniqueGamma() => TypedResults.Ok("unique-gamma"); + } + + internal static class DeltaEndpoints + { + [MapGet("/delta/unique")] public static Ok UniqueDelta() => TypedResults.Ok("unique-delta"); + [MapGet("/delta/collision")] public static Ok Collision() => TypedResults.Ok("delta-collision"); + } + """; + } public static string BuildContractsAndBindingSource( bool includeBindingNames, @@ -389,7 +333,8 @@ public static string BuildContractsAndBindingSource( string? acceptsContentType1, string? acceptsContentType2, string? producesContentType1, - string? producesContentType2) + string? producesContentType2 + ) { var builder = new StringBuilder(); builder.AppendLine("using Microsoft.AspNetCore.Mvc;"); @@ -405,29 +350,19 @@ public static string BuildContractsAndBindingSource( } if (includeDisplayName) - { builder.AppendLine(" [DisplayName(\"Contract endpoint\")]"); - } if (includeTags) - { builder.AppendLine(" [Tags(\"Contracts\", \"Bindings\")]"); - } if (excludeFromDescription) - { builder.AppendLine(" [ExcludeFromDescription]"); - } if (allowAnonymous) - { builder.AppendLine(" [AllowAnonymous]"); - } if (methodRequiresAuthorization) - { builder.AppendLine(" [RequireAuthorization(\"ContractsPolicy\")]"); - } builder.AppendLine(" [MapGet(\"/contracts/{id:int}\")]"); @@ -438,52 +373,36 @@ public static string BuildContractsAndBindingSource( } if (includeGenericAccepts) - { builder.AppendLine($" [Accepts(\"{acceptsContentType1 ?? "application/json"}\")]"); - } if (includeProducesResponse) { var secondProduces = string.IsNullOrWhiteSpace(producesContentType2) ? "" : $", \"{producesContentType2}\""; - builder.AppendLine($" [ProducesResponse(200, \"{producesContentType1 ?? "application/json"}\"{secondProduces}, ResponseType = typeof(ResponseRecord))]"); + builder.AppendLine( + $" [ProducesResponse(200, \"{producesContentType1 ?? "application/json"}\"{secondProduces}, ResponseType = typeof(ResponseRecord))]" + ); } if (includeProducesProblem) - { builder.AppendLine($" [ProducesProblem(500, \"{producesContentType1 ?? "application/problem+json"}\")]"); - } if (includeProducesValidationProblem) - { builder.AppendLine($" [ProducesValidationProblem(422, \"{producesContentType1 ?? "application/problem+json"}\")]"); - } builder.AppendLine(" public static async Task, NotFound>> Handle("); - builder.AppendLine(includeBindingNames - ? " [FromRoute(Name = \"route-id\")] int id," - : " [FromRoute] int id,"); - builder.AppendLine(includeBindingNames - ? " [FromQuery(Name = \"filter-term\")] string? filter," - : " [FromQuery] string? filter,"); - builder.AppendLine(includeBindingNames - ? " [FromHeader(Name = \"x-trace-id\")] string? traceId," - : " [FromHeader] string? traceId,"); + builder.AppendLine(includeBindingNames ? " [FromRoute(Name = \"route-id\")] int id," : " [FromRoute] int id,"); + builder.AppendLine(includeBindingNames ? " [FromQuery(Name = \"filter-term\")] string? filter," : " [FromQuery] string? filter,"); + builder.AppendLine(includeBindingNames ? " [FromHeader(Name = \"x-trace-id\")] string? traceId," : " [FromHeader] string? traceId,"); builder.AppendLine(" [FromBody] RequestRecord request,"); if (includeAsParameters) - { builder.AppendLine(" [AsParameters] AdditionalParameters parameters,"); - } if (includeFromServices) - { builder.AppendLine(" [FromServices] IServiceProvider services,"); - } if (includeFromKeyedServices) - { builder.AppendLine(" [FromKeyedServices(\"special\")] object keyed,"); - } builder.AppendLine(" CancellationToken cancellationToken)"); builder.AppendLine(" {"); @@ -497,9 +416,7 @@ public static string BuildContractsAndBindingSource( builder.AppendLine("internal sealed record ResponseRecord(int Value);"); if (includeAsParameters) - { builder.AppendLine("internal sealed record AdditionalParameters(string? Search, int? Page);"); - } return builder.ToString(); } diff --git a/tests/GeneratedEndpoints.Tests/Common/TestHelpers.cs b/tests/GeneratedEndpoints.Tests/Common/TestHelpers.cs index 8771b24..e4a4dfb 100644 --- a/tests/GeneratedEndpoints.Tests/Common/TestHelpers.cs +++ b/tests/GeneratedEndpoints.Tests/Common/TestHelpers.cs @@ -11,7 +11,8 @@ public static GeneratorDriverRunResult RunGenerator(IEnumerable sources) { var cSharpParseOptions = new CSharpParseOptions(LanguageVersion.CSharp11).WithPreprocessorSymbols("NET10_0_OR_GREATER"); var cSharpCompilationOptions = new CSharpCompilationOptions(OutputKind.NetModule).WithNullableContextOptions(NullableContextOptions.Enable); - var (_, result) = IncrementalGenerator.RunWithDiagnostics(sources, cSharpParseOptions, AspNet100.References.All, cSharpCompilationOptions); + var (_, result) = + IncrementalGenerator.RunWithDiagnostics(sources, cSharpParseOptions, AspNet100.References.All, cSharpCompilationOptions); return result; } diff --git a/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj b/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj index 9078773..1bd4374 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj +++ b/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj @@ -39,21 +39,21 @@ - - GeneratedSourceTests.cs - - - IndividualTests.cs - - - GeneratedSourceTests.cs - - - GeneratedSourceTests.cs - - - GeneratedSourceTests.cs - + + GeneratedSourceTests.cs + + + IndividualTests.cs + + + GeneratedSourceTests.cs + + + GeneratedSourceTests.cs + + + GeneratedSourceTests.cs + diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs index 430d05b..dd38cbb 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs @@ -21,11 +21,9 @@ public async Task MapFallbackScenarios(bool withNamespace, bool includeDefaultFa { var sources = TestHelpers.GetSources(SourceFactory.BuildFallbackSource(includeDefaultFallback, includeCustomFallback, customRoute), withNamespace); var result = TestHelpers.RunGenerator(sources); - var scenario = ScenarioNamer.Create(nameof(MapFallbackScenarios), - ("Namespace", withNamespace), - ("Default", includeDefaultFallback), - ("Custom", includeCustomFallback), - ("Route", customRoute ?? "default")); + var scenario = ScenarioNamer.Create(nameof(MapFallbackScenarios), ("Namespace", withNamespace), ("Default", includeDefaultFallback), + ("Custom", includeCustomFallback), ("Route", customRoute ?? "default") + ); await result.VerifyAsync("AddEndpointHandlers.g.cs") .UseMethodName($"{scenario}_AddEndpointHandlers"); @@ -35,11 +33,21 @@ await result.VerifyAsync("MapEndpointHandlers.g.cs") } [Theory] - [InlineData(true, true, false, true, false, true, true, "*.contoso.com", "api.contoso.com", true, "NamedCorsPolicy", false, null, true, "RatePolicy", true, true, "TimeoutPolicy", false, 5, "Reporting", true)] - [InlineData(false, false, true, false, true, false, true, null, "services.contoso.com", false, null, true, "MethodCors", true, null, false, false, null, true, -1, null, false)] - [InlineData(true, true, true, true, true, true, false, "*.example.com", null, true, null, true, null, false, null, false, true, null, true, 0, "Operations", true)] - [InlineData(false, false, false, true, false, true, false, null, "*.alt.com", false, "CorsDefault", false, null, false, null, true, false, null, false, 10, null, false)] - [InlineData(true, false, true, false, true, false, true, "api.alt.com", null, true, null, true, "MethodCors", true, "BurstPolicy", false, true, "TimeoutPolicy", true, -5, "Docs", true)] + [InlineData(true, true, false, true, false, true, true, "*.contoso.com", "api.contoso.com", true, "NamedCorsPolicy", false, null, true, "RatePolicy", true, + true, "TimeoutPolicy", false, 5, "Reporting", true + )] + [InlineData(false, false, true, false, true, false, true, null, "services.contoso.com", false, null, true, "MethodCors", true, null, false, false, null, + true, -1, null, false + )] + [InlineData(true, true, true, true, true, true, false, "*.example.com", null, true, null, true, null, false, null, false, true, null, true, 0, "Operations", + true + )] + [InlineData(false, false, false, true, false, true, false, null, "*.alt.com", false, "CorsDefault", false, null, false, null, true, false, null, false, 10, + null, false + )] + [InlineData(true, false, true, false, true, false, true, "api.alt.com", null, true, null, true, "MethodCors", true, "BurstPolicy", false, true, + "TimeoutPolicy", true, -5, "Docs", true + )] public async Task AuthorizationAndMetadataMatrix( bool withNamespace, bool classAllowAnonymous, @@ -62,52 +70,24 @@ public async Task AuthorizationAndMetadataMatrix( bool disableRequestTimeout, int orderValue, string? groupName, - bool excludeFromDescription) + bool excludeFromDescription + ) { - var source = SourceFactory.BuildAuthorizationMatrixSource( - classAllowAnonymous, - methodAllowAnonymous, - classRequireAuthorization, - methodRequireAuthorization, - classTags, - methodTags, - classHost, - methodHost, - classRequireCors, - classCorsPolicy, - methodRequireCors, - methodCorsPolicy, - requireRateLimiting, - rateLimitingPolicy, - applyShortCircuit, - applyRequestTimeout, - requestTimeoutPolicy, - disableRequestTimeout, - orderValue, - groupName, - excludeFromDescription); + var source = SourceFactory.BuildAuthorizationMatrixSource(classAllowAnonymous, methodAllowAnonymous, classRequireAuthorization, + methodRequireAuthorization, classTags, methodTags, classHost, methodHost, classRequireCors, classCorsPolicy, methodRequireCors, methodCorsPolicy, + requireRateLimiting, rateLimitingPolicy, applyShortCircuit, applyRequestTimeout, requestTimeoutPolicy, disableRequestTimeout, orderValue, groupName, + excludeFromDescription + ); var sources = TestHelpers.GetSources(source, withNamespace); var result = TestHelpers.RunGenerator(sources); - var scenario = ScenarioNamer.Create(nameof(AuthorizationAndMetadataMatrix), - ("Namespace", withNamespace), - ("ClassAnon", classAllowAnonymous), - ("MethodAnon", methodAllowAnonymous), - ("ClassAuth", classRequireAuthorization), - ("MethodAuth", methodRequireAuthorization), - ("ClassTags", classTags), - ("MethodTags", methodTags), - ("ClassHost", classHost ?? "none"), - ("MethodHost", methodHost ?? "none"), - ("ClassCors", classRequireCors), - ("MethodCors", methodRequireCors), - ("RateLimit", requireRateLimiting), - ("ShortCircuit", applyShortCircuit), - ("RequestTimeout", applyRequestTimeout), - ("DisableTimeout", disableRequestTimeout), - ("Order", orderValue), - ("Group", groupName ?? "none"), - ("Exclude", excludeFromDescription)); + var scenario = ScenarioNamer.Create(nameof(AuthorizationAndMetadataMatrix), ("Namespace", withNamespace), ("ClassAnon", classAllowAnonymous), + ("MethodAnon", methodAllowAnonymous), ("ClassAuth", classRequireAuthorization), ("MethodAuth", methodRequireAuthorization), + ("ClassTags", classTags), ("MethodTags", methodTags), ("ClassHost", classHost ?? "none"), ("MethodHost", methodHost ?? "none"), + ("ClassCors", classRequireCors), ("MethodCors", methodRequireCors), ("RateLimit", requireRateLimiting), ("ShortCircuit", applyShortCircuit), + ("RequestTimeout", applyRequestTimeout), ("DisableTimeout", disableRequestTimeout), ("Order", orderValue), ("Group", groupName ?? "none"), + ("Exclude", excludeFromDescription) + ); await result.VerifyAsync("AddEndpointHandlers.g.cs") .UseMethodName($"{scenario}_AddEndpointHandlers"); @@ -130,28 +110,19 @@ public async Task ConfigureAndFiltersMatrix( bool includeMethodLevelFilter, bool includeGenericFilter, bool configureRegistersFilter, - string metadataValue) + string metadataValue + ) { - var source = SourceFactory.BuildConfigureAndFiltersSource( - configureWithServiceProvider, - configureAddsMetadata, - includeClassLevelFilter, - includeMethodLevelFilter, - includeGenericFilter, - configureRegistersFilter, - metadataValue); + var source = SourceFactory.BuildConfigureAndFiltersSource(configureWithServiceProvider, configureAddsMetadata, includeClassLevelFilter, + includeMethodLevelFilter, includeGenericFilter, configureRegistersFilter, metadataValue + ); var sources = TestHelpers.GetSources(source, withNamespace); var result = TestHelpers.RunGenerator(sources); - var scenario = ScenarioNamer.Create(nameof(ConfigureAndFiltersMatrix), - ("Namespace", withNamespace), - ("SvcProvider", configureWithServiceProvider), - ("Metadata", configureAddsMetadata), - ("ClassFilter", includeClassLevelFilter), - ("MethodFilter", includeMethodLevelFilter), - ("GenericFilter", includeGenericFilter), - ("ConfigureFilter", configureRegistersFilter), - ("Value", metadataValue)); + var scenario = ScenarioNamer.Create(nameof(ConfigureAndFiltersMatrix), ("Namespace", withNamespace), ("SvcProvider", configureWithServiceProvider), + ("Metadata", configureAddsMetadata), ("ClassFilter", includeClassLevelFilter), ("MethodFilter", includeMethodLevelFilter), + ("GenericFilter", includeGenericFilter), ("ConfigureFilter", configureRegistersFilter), ("Value", metadataValue) + ); await result.VerifyAsync("AddEndpointHandlers.g.cs") .UseMethodName($"{scenario}_AddEndpointHandlers"); @@ -178,36 +149,19 @@ public async Task HttpMethodMatrix( bool includeQuery, bool includeTrace, bool includeConnect, - bool includeMethodNameCollision) + bool includeMethodNameCollision + ) { - var source = SourceFactory.BuildHttpMethodMatrixSource( - includeGet, - includePost, - includePut, - includeDelete, - includeOptions, - includeHead, - includePatch, - includeQuery, - includeTrace, - includeConnect, - includeMethodNameCollision); + var source = SourceFactory.BuildHttpMethodMatrixSource(includeGet, includePost, includePut, includeDelete, includeOptions, includeHead, includePatch, + includeQuery, includeTrace, includeConnect, includeMethodNameCollision + ); var sources = TestHelpers.GetSources(source, withNamespace); var result = TestHelpers.RunGenerator(sources); - var scenario = ScenarioNamer.Create(nameof(HttpMethodMatrix), - ("Namespace", withNamespace), - ("Get", includeGet), - ("Post", includePost), - ("Put", includePut), - ("Delete", includeDelete), - ("Options", includeOptions), - ("Head", includeHead), - ("Patch", includePatch), - ("Query", includeQuery), - ("Trace", includeTrace), - ("Connect", includeConnect), - ("Collision", includeMethodNameCollision)); + var scenario = ScenarioNamer.Create(nameof(HttpMethodMatrix), ("Namespace", withNamespace), ("Get", includeGet), ("Post", includePost), + ("Put", includePut), ("Delete", includeDelete), ("Options", includeOptions), ("Head", includeHead), ("Patch", includePatch), + ("Query", includeQuery), ("Trace", includeTrace), ("Connect", includeConnect), ("Collision", includeMethodNameCollision) + ); await result.VerifyAsync("AddEndpointHandlers.g.cs") .UseMethodName($"{scenario}_AddEndpointHandlers"); @@ -217,10 +171,16 @@ await result.VerifyAsync("MapEndpointHandlers.g.cs") } [Theory] - [InlineData(true, true, true, true, true, true, true, true, true, true, true, true, true, false, true, false, "application/xml", "text/xml", "application/json", "text/json")] - [InlineData(false, false, true, false, false, true, false, true, true, false, false, false, true, true, false, true, "application/custom", null, "application/problem+json", null)] + [InlineData(true, true, true, true, true, true, true, true, true, true, true, true, true, false, true, false, "application/xml", "text/xml", + "application/json", "text/json" + )] + [InlineData(false, false, true, false, false, true, false, true, true, false, false, false, true, true, false, true, "application/custom", null, + "application/problem+json", null + )] [InlineData(true, true, false, true, true, false, true, false, false, true, true, true, false, false, true, true, null, null, null, null)] - [InlineData(false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, "application/xml", null, "application/json", null)] + [InlineData(false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, "application/xml", null, + "application/json", null + )] [InlineData(true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, null, "text/plain", null, "text/plain")] public async Task ContractsAndBindingMatrix( bool withNamespace, @@ -242,48 +202,23 @@ public async Task ContractsAndBindingMatrix( string? acceptsContentType1, string? acceptsContentType2, string? producesContentType1, - string? producesContentType2) + string? producesContentType2 + ) { - var source = SourceFactory.BuildContractsAndBindingSource( - includeBindingNames, - includeAsParameters, - includeFromServices, - includeFromKeyedServices, - includeAccepts, - includeGenericAccepts, - includeProducesResponse, - includeProducesProblem, - includeProducesValidationProblem, - includeSummaryAndDescription, - includeDisplayName, - includeTags, - excludeFromDescription, - allowAnonymous, - methodRequiresAuthorization, - acceptsContentType1, - acceptsContentType2, - producesContentType1, - producesContentType2); + var source = SourceFactory.BuildContractsAndBindingSource(includeBindingNames, includeAsParameters, includeFromServices, includeFromKeyedServices, + includeAccepts, includeGenericAccepts, includeProducesResponse, includeProducesProblem, includeProducesValidationProblem, + includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, methodRequiresAuthorization, + acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 + ); var sources = TestHelpers.GetSources(source, withNamespace); var result = TestHelpers.RunGenerator(sources); - var scenario = ScenarioNamer.Create(nameof(ContractsAndBindingMatrix), - ("Namespace", withNamespace), - ("BindingNames", includeBindingNames), - ("AsParameters", includeAsParameters), - ("Services", includeFromServices), - ("KeyedServices", includeFromKeyedServices), - ("Accepts", includeAccepts), - ("GenericAccepts", includeGenericAccepts), - ("Produces", includeProducesResponse), - ("Problem", includeProducesProblem), - ("Validation", includeProducesValidationProblem), - ("Summary", includeSummaryAndDescription), - ("DisplayName", includeDisplayName), - ("Tags", includeTags), - ("Exclude", excludeFromDescription), - ("AllowAnon", allowAnonymous), - ("MethodAuth", methodRequiresAuthorization)); + var scenario = ScenarioNamer.Create(nameof(ContractsAndBindingMatrix), ("Namespace", withNamespace), ("BindingNames", includeBindingNames), + ("AsParameters", includeAsParameters), ("Services", includeFromServices), ("KeyedServices", includeFromKeyedServices), ("Accepts", includeAccepts), + ("GenericAccepts", includeGenericAccepts), ("Produces", includeProducesResponse), ("Problem", includeProducesProblem), + ("Validation", includeProducesValidationProblem), ("Summary", includeSummaryAndDescription), ("DisplayName", includeDisplayName), + ("Tags", includeTags), ("Exclude", excludeFromDescription), ("AllowAnon", allowAnonymous), ("MethodAuth", methodRequiresAuthorization) + ); await result.VerifyAsync("AddEndpointHandlers.g.cs") .UseMethodName($"{scenario}_AddEndpointHandlers"); diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.cs b/tests/GeneratedEndpoints.Tests/IndividualTests.cs index ea30eb9..be6ee72 100644 --- a/tests/GeneratedEndpoints.Tests/IndividualTests.cs +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.cs @@ -14,7 +14,7 @@ public IndividualTests() [Fact] public async Task DefaultFallbackOnly() { - var source = FallbackScenario(includeDefault: true); + var source = FallbackScenario(true); await VerifyIndividualAsync(source, nameof(DefaultFallbackOnly)); } @@ -28,7 +28,7 @@ public async Task CustomFallbackRoute() [Fact] public async Task ClassAllowAnonymous() { - var source = AuthorizationScenario(classAllowAnonymous: true); + var source = AuthorizationScenario(true); await VerifyIndividualAsync(source, nameof(ClassAllowAnonymous)); } @@ -175,21 +175,9 @@ public async Task OrderMetadata() [Fact] public async Task ClassMapGroup() { - var source = AuthorizationScenario( - classRequireAuthorization: true, - classTags: true, - classHost: "*.individual.com", - classRequireCors: true, - classCorsPolicy: "ClassCors", - applyShortCircuit: true, - applyRequestTimeout: true, - requestTimeoutPolicy: "ClassTimeout", - orderValue: 2, - groupName: "ClassGroup", - excludeFromDescription: true, - methodAllowAnonymous: true, - methodTags: true, - mapGroupPattern: "/individuals" + var source = AuthorizationScenario(classRequireAuthorization: true, classTags: true, classHost: "*.individual.com", classRequireCors: true, + classCorsPolicy: "ClassCors", applyShortCircuit: true, applyRequestTimeout: true, requestTimeoutPolicy: "ClassTimeout", orderValue: 2, + groupName: "ClassGroup", excludeFromDescription: true, methodAllowAnonymous: true, methodTags: true, mapGroupPattern: "/individuals" ); await VerifyIndividualAsync(source, nameof(ClassMapGroup)); } @@ -232,7 +220,7 @@ public async Task GenericEndpointFilter() [Fact] public async Task ConfigureWithServiceProvider() { - var source = ConfigureScenario(configureWithServiceProvider: true); + var source = ConfigureScenario(true); await VerifyIndividualAsync(source, nameof(ConfigureWithServiceProvider)); } @@ -253,7 +241,7 @@ public async Task ConfigureRegistersFilter() [Fact] public async Task MapGetEndpoint() { - var source = HttpMethodScenario(includeGet: true); + var source = HttpMethodScenario(true); await VerifyIndividualAsync(source, nameof(MapGetEndpoint)); } @@ -323,7 +311,7 @@ public async Task MapConnectEndpoint() [Fact] public async Task MethodNameCollision() { - var source = HttpMethodScenario(includeGet: true, includeMethodNameCollision: true); + var source = HttpMethodScenario(true, includeMethodNameCollision: true); await VerifyIndividualAsync(source, nameof(MethodNameCollision)); } @@ -337,7 +325,7 @@ public async Task MultipleEndpointNameCollisions() [Fact] public async Task BindingNames() { - var source = ContractScenario(includeBindingNames: true); + var source = ContractScenario(true); await VerifyIndividualAsync(source, nameof(BindingNames)); } @@ -473,7 +461,9 @@ await result.VerifyAsync("MapEndpointHandlers.g.cs") } private static string FallbackScenario(bool includeDefault = false, bool includeCustom = false, string? customRoute = null) - => SourceFactory.BuildFallbackSource(includeDefault, includeCustom, customRoute); + { + return SourceFactory.BuildFallbackSource(includeDefault, includeCustom, customRoute); + } private static string AuthorizationScenario( bool classAllowAnonymous = false, @@ -499,32 +489,15 @@ private static string AuthorizationScenario( bool excludeFromDescription = false, string? mapGroupPattern = null, bool classDisableValidation = false, - bool methodDisableValidation = false) - => SourceFactory.BuildAuthorizationMatrixSource( - classAllowAnonymous, - methodAllowAnonymous, - classRequireAuthorization, - methodRequireAuthorization, - classTags, - methodTags, - classHost, - methodHost, - classRequireCors, - classCorsPolicy, - methodRequireCors, - methodCorsPolicy, - requireRateLimiting, - rateLimitingPolicy, - applyShortCircuit, - applyRequestTimeout, - requestTimeoutPolicy, - disableRequestTimeout, - orderValue, - groupName, - excludeFromDescription, - mapGroupPattern, - classDisableValidation, - methodDisableValidation); + bool methodDisableValidation = false + ) + { + return SourceFactory.BuildAuthorizationMatrixSource(classAllowAnonymous, methodAllowAnonymous, classRequireAuthorization, methodRequireAuthorization, + classTags, methodTags, classHost, methodHost, classRequireCors, classCorsPolicy, methodRequireCors, methodCorsPolicy, requireRateLimiting, + rateLimitingPolicy, applyShortCircuit, applyRequestTimeout, requestTimeoutPolicy, disableRequestTimeout, orderValue, groupName, + excludeFromDescription, mapGroupPattern, classDisableValidation, methodDisableValidation + ); + } private static string ConfigureScenario( bool configureWithServiceProvider = false, @@ -533,15 +506,13 @@ private static string ConfigureScenario( bool includeMethodLevelFilter = false, bool includeGenericFilter = false, bool configureRegistersFilter = false, - string metadataValue = "Individual") - => SourceFactory.BuildConfigureAndFiltersSource( - configureWithServiceProvider, - configureAddsMetadata, - includeClassLevelFilter, - includeMethodLevelFilter, - includeGenericFilter, - configureRegistersFilter, - metadataValue); + string metadataValue = "Individual" + ) + { + return SourceFactory.BuildConfigureAndFiltersSource(configureWithServiceProvider, configureAddsMetadata, includeClassLevelFilter, + includeMethodLevelFilter, includeGenericFilter, configureRegistersFilter, metadataValue + ); + } private static string HttpMethodScenario( bool includeGet = false, @@ -554,22 +525,18 @@ private static string HttpMethodScenario( bool includeQuery = false, bool includeTrace = false, bool includeConnect = false, - bool includeMethodNameCollision = false) - => SourceFactory.BuildHttpMethodMatrixSource( - includeGet, - includePost, - includePut, - includeDelete, - includeOptions, - includeHead, - includePatch, - includeQuery, - includeTrace, - includeConnect, - includeMethodNameCollision); + bool includeMethodNameCollision = false + ) + { + return SourceFactory.BuildHttpMethodMatrixSource(includeGet, includePost, includePut, includeDelete, includeOptions, includeHead, includePatch, + includeQuery, includeTrace, includeConnect, includeMethodNameCollision + ); + } private static string EndpointNameCollisionScenario() - => SourceFactory.BuildEndpointNameCollisionSource(); + { + return SourceFactory.BuildEndpointNameCollisionSource(); + } private static string ContractScenario( bool includeBindingNames = false, @@ -590,59 +557,49 @@ private static string ContractScenario( string? acceptsContentType1 = null, string? acceptsContentType2 = null, string? producesContentType1 = null, - string? producesContentType2 = null) - => SourceFactory.BuildContractsAndBindingSource( - includeBindingNames, - includeAsParameters, - includeFromServices, - includeFromKeyedServices, - includeAccepts, - includeGenericAccepts, - includeProducesResponse, - includeProducesProblem, - includeProducesValidationProblem, - includeSummaryAndDescription, - includeDisplayName, - includeTags, - excludeFromDescription, - allowAnonymous, - methodRequiresAuthorization, - acceptsContentType1, - acceptsContentType2, - producesContentType1, - producesContentType2); + string? producesContentType2 = null + ) + { + return SourceFactory.BuildContractsAndBindingSource(includeBindingNames, includeAsParameters, includeFromServices, includeFromKeyedServices, + includeAccepts, includeGenericAccepts, includeProducesResponse, includeProducesProblem, includeProducesValidationProblem, + includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, methodRequiresAuthorization, + acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 + ); + } private static string AsyncHandlerScenario() - => """ - using System.Threading.Tasks; - - internal sealed class AsyncHandlerEndpoints - { - [MapGet("/task")] - public async Task TaskOnly() - { - await Task.Yield(); - } - - [MapGet("/task-result")] - public async Task, NotFound>> TaskWithResult(int id) - { - await Task.Yield(); - return id >= 0 ? TypedResults.Ok("task") : TypedResults.NotFound(); - } - - [MapPost("/valuetask")] - public async ValueTask ValueTaskOnly() - { - await Task.Yield(); - } - - [MapPost("/valuetask-result")] - public async ValueTask, NotFound>> ValueTaskWithResult(int id) - { - await Task.Yield(); - return id >= 0 ? TypedResults.Ok("value") : TypedResults.NotFound(); - } - } - """; + { + return """ + using System.Threading.Tasks; + + internal sealed class AsyncHandlerEndpoints + { + [MapGet("/task")] + public async Task TaskOnly() + { + await Task.Yield(); + } + + [MapGet("/task-result")] + public async Task, NotFound>> TaskWithResult(int id) + { + await Task.Yield(); + return id >= 0 ? TypedResults.Ok("task") : TypedResults.NotFound(); + } + + [MapPost("/valuetask")] + public async ValueTask ValueTaskOnly() + { + await Task.Yield(); + } + + [MapPost("/valuetask-result")] + public async ValueTask, NotFound>> ValueTaskWithResult(int id) + { + await Task.Yield(); + return id >= 0 ? TypedResults.Ok("value") : TypedResults.NotFound(); + } + } + """; + } }