diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index feb2322..e6c0677 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -15,45 +15,22 @@ public sealed class MinimalApiGenerator : IIncrementalGenerator private const string BaseNamespace = "Microsoft.AspNetCore.Generated"; private const string AttributesNamespace = $"{BaseNamespace}.Attributes"; - private const string MapGetAttributeName = "MapGetAttribute"; - private const string MapGetAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapGetAttributeName}"; - private const string MapGetAttributeHint = $"{MapGetAttributeFullyQualifiedName}.gs.cs"; - - private const string MapPostAttributeName = "MapPostAttribute"; - private const string MapPostAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapPostAttributeName}"; - private const string MapPostAttributeHint = $"{MapPostAttributeFullyQualifiedName}.gs.cs"; - - private const string MapPutAttributeName = "MapPutAttribute"; - private const string MapPutAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapPutAttributeName}"; - private const string MapPutAttributeHint = $"{MapPutAttributeFullyQualifiedName}.gs.cs"; - - private const string MapDeleteAttributeName = "MapDeleteAttribute"; - private const string MapDeleteAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapDeleteAttributeName}"; - private const string MapDeleteAttributeHint = $"{MapDeleteAttributeFullyQualifiedName}.gs.cs"; - - private const string MapOptionsAttributeName = "MapOptionsAttribute"; - private const string MapOptionsAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapOptionsAttributeName}"; - private const string MapOptionsAttributeHint = $"{MapOptionsAttributeFullyQualifiedName}.gs.cs"; - - private const string MapHeadAttributeName = "MapHeadAttribute"; - private const string MapHeadAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapHeadAttributeName}"; - private const string MapHeadAttributeHint = $"{MapHeadAttributeFullyQualifiedName}.gs.cs"; - - private const string MapPatchAttributeName = "MapPatchAttribute"; - private const string MapPatchAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapPatchAttributeName}"; - private const string MapPatchAttributeHint = $"{MapPatchAttributeFullyQualifiedName}.gs.cs"; - - private const string MapQueryAttributeName = "MapQueryAttribute"; - private const string MapQueryAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapQueryAttributeName}"; - private const string MapQueryAttributeHint = $"{MapQueryAttributeFullyQualifiedName}.gs.cs"; - - private const string MapTraceAttributeName = "MapTraceAttribute"; - private const string MapTraceAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapTraceAttributeName}"; - private const string MapTraceAttributeHint = $"{MapTraceAttributeFullyQualifiedName}.gs.cs"; - - private const string MapConnectAttributeName = "MapConnectAttribute"; - private const string MapConnectAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapConnectAttributeName}"; - private const string MapConnectAttributeHint = $"{MapConnectAttributeFullyQualifiedName}.gs.cs"; + private static readonly ImmutableArray HttpAttributeDefinitions = + [ + CreateHttpAttributeDefinition("MapGetAttribute", "GET"), + CreateHttpAttributeDefinition("MapPostAttribute", "POST"), + CreateHttpAttributeDefinition("MapPutAttribute", "PUT"), + CreateHttpAttributeDefinition("MapPatchAttribute", "PATCH"), + CreateHttpAttributeDefinition("MapDeleteAttribute", "DELETE"), + CreateHttpAttributeDefinition("MapOptionsAttribute", "OPTIONS"), + CreateHttpAttributeDefinition("MapHeadAttribute", "HEAD"), + CreateHttpAttributeDefinition("MapQueryAttribute", "QUERY"), + CreateHttpAttributeDefinition("MapTraceAttribute", "TRACE"), + CreateHttpAttributeDefinition("MapConnectAttribute", "CONNECT"), + ]; + + private static readonly ImmutableDictionary HttpAttributeDefinitionsByName = + HttpAttributeDefinitions.ToImmutableDictionary(static definition => definition.Name); private const string NameAttributeNamedParameter = "Name"; private const string SummaryAttributeNamedParameter = "Summary"; @@ -116,103 +93,55 @@ public sealed class MinimalApiGenerator : IIncrementalGenerator #nullable enable """; + private static HttpAttributeDefinition CreateHttpAttributeDefinition(string attributeName, string verb) + { + var fullyQualifiedName = $"{AttributesNamespace}.{attributeName}"; + return new HttpAttributeDefinition(attributeName, fullyQualifiedName, $"{fullyQualifiedName}.gs.cs", verb); + } + public void Initialize(IncrementalGeneratorInitializationContext context) { context.RegisterPostInitializationOutput(RegisterAttributes); - var getRequestHandlers = context.SyntaxProvider - .ForAttributeWithMetadataName(MapGetAttributeFullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) - .WhereNotNull() - .Collect(); - - var postRequestHandlers = context.SyntaxProvider - .ForAttributeWithMetadataName(MapPostAttributeFullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) - .WhereNotNull() - .Collect(); - - var putRequestHandlers = context.SyntaxProvider - .ForAttributeWithMetadataName(MapPutAttributeFullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) - .WhereNotNull() - .Collect(); - - var deleteRequestHandlers = context.SyntaxProvider - .ForAttributeWithMetadataName(MapDeleteAttributeFullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) - .WhereNotNull() - .Collect(); - - var optionsRequestHandlers = context.SyntaxProvider - .ForAttributeWithMetadataName(MapOptionsAttributeFullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) - .WhereNotNull() - .Collect(); - - var headRequestHandlers = context.SyntaxProvider - .ForAttributeWithMetadataName(MapHeadAttributeFullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) - .WhereNotNull() - .Collect(); - - var patchRequestHandlers = context.SyntaxProvider - .ForAttributeWithMetadataName(MapPatchAttributeFullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) - .WhereNotNull() - .Collect(); - - var queryRequestHandlers = context.SyntaxProvider - .ForAttributeWithMetadataName(MapQueryAttributeFullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) - .WhereNotNull() - .Collect(); - - var traceRequestHandlers = context.SyntaxProvider - .ForAttributeWithMetadataName(MapTraceAttributeFullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) - .WhereNotNull() - .Collect(); - - var connectRequestHandlers = context.SyntaxProvider - .ForAttributeWithMetadataName(MapConnectAttributeFullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) - .WhereNotNull() - .Collect(); - - var requestHandlers = getRequestHandlers.Combine(postRequestHandlers) - .Select(static (x, _) => x.Left.AddRange(x.Right)) - .Combine(putRequestHandlers) - .Select(static (x, _) => x.Left.AddRange(x.Right)) - .Combine(patchRequestHandlers) - .Select(static (x, _) => x.Left.AddRange(x.Right)) - .Combine(deleteRequestHandlers) - .Select(static (x, _) => x.Left.AddRange(x.Right)) - .Combine(optionsRequestHandlers) - .Select(static (x, _) => x.Left.AddRange(x.Right)) - .Combine(headRequestHandlers) - .Select(static (x, _) => x.Left.AddRange(x.Right)) - .Combine(queryRequestHandlers) - .Select(static (x, _) => x.Left.AddRange(x.Right)) - .Combine(traceRequestHandlers) - .Select(static (x, _) => x.Left.AddRange(x.Right)) - .Combine(connectRequestHandlers) - .Select(static (x, _) => x.Left.AddRange(x.Right)); + var requestHandlerProviders = ImmutableArray.CreateBuilder>>( + HttpAttributeDefinitions.Length); + + foreach (var definition in HttpAttributeDefinitions) + { + var handlers = context.SyntaxProvider + .ForAttributeWithMetadataName(definition.FullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) + .WhereNotNull() + .Collect(); + + requestHandlerProviders.Add(handlers); + } + + var requestHandlers = CombineRequestHandlers(requestHandlerProviders.MoveToImmutable()); context.RegisterSourceOutput(requestHandlers, GenerateSource); } - private static void RegisterAttributes(IncrementalGeneratorPostInitializationContext context) + private static IncrementalValueProvider> CombineRequestHandlers( + ImmutableArray>> handlerProviders) { - // Definitions for HTTP method attributes - var httpAttributes = new[] + if (handlerProviders.IsDefaultOrEmpty) + throw new InvalidOperationException("No HTTP attribute definitions were provided."); + + var combined = handlerProviders[0]; + for (var i = 1; i < handlerProviders.Length; i++) { - (Name: MapGetAttributeName, FullyQualified: MapGetAttributeFullyQualifiedName, Hint: MapGetAttributeHint, Verb: "GET"), - (Name: MapPostAttributeName, FullyQualified: MapPostAttributeFullyQualifiedName, Hint: MapPostAttributeHint, Verb: "POST"), - (Name: MapPutAttributeName, FullyQualified: MapPutAttributeFullyQualifiedName, Hint: MapPutAttributeHint, Verb: "PUT"), - (Name: MapDeleteAttributeName, FullyQualified: MapDeleteAttributeFullyQualifiedName, Hint: MapDeleteAttributeHint, Verb: "DELETE"), - (Name: MapOptionsAttributeName, FullyQualified: MapOptionsAttributeFullyQualifiedName, Hint: MapOptionsAttributeHint, Verb: "OPTIONS"), - (Name: MapHeadAttributeName, FullyQualified: MapHeadAttributeFullyQualifiedName, Hint: MapHeadAttributeHint, Verb: "HEAD"), - (Name: MapPatchAttributeName, FullyQualified: MapPatchAttributeFullyQualifiedName, Hint: MapPatchAttributeHint, Verb: "PATCH"), - (Name: MapQueryAttributeName, FullyQualified: MapQueryAttributeFullyQualifiedName, Hint: MapQueryAttributeHint, Verb: "QUERY"), - (Name: MapTraceAttributeName, FullyQualified: MapTraceAttributeFullyQualifiedName, Hint: MapTraceAttributeHint, Verb: "TRACE"), - (Name: MapConnectAttributeName, FullyQualified: MapConnectAttributeFullyQualifiedName, Hint: MapConnectAttributeHint, Verb: "CONNECT"), - }; + combined = combined.Combine(handlerProviders[i]).Select(static (x, _) => x.Left.AddRange(x.Right)); + } - foreach (var (name, _, hint, verb) in httpAttributes) + return combined; + } + + private static void RegisterAttributes(IncrementalGeneratorPostInitializationContext context) + { + foreach (var definition in HttpAttributeDefinitions) { - var source = GenerateHttpAttributeSource(FileHeader, AttributesNamespace, name, verb); - context.AddSource(hint, SourceText.From(source, Encoding.UTF8)); + var source = GenerateHttpAttributeSource(FileHeader, AttributesNamespace, definition.Name, definition.Verb); + context.AddSource(definition.Hint, SourceText.From(source, Encoding.UTF8)); } // RequireAuthorization @@ -646,20 +575,9 @@ CancellationToken cancellationToken var attributeName = attribute.AttributeClass?.Name ?? ""; - var httpMethod = attributeName switch - { - MapGetAttributeName => "Get", - MapPostAttributeName => "Post", - MapPutAttributeName => "Put", - MapDeleteAttributeName => "Delete", - MapOptionsAttributeName => "OPTIONS", - MapHeadAttributeName => "HEAD", - MapPatchAttributeName => "Patch", - MapQueryAttributeName => "QUERY", - MapTraceAttributeName => "TRACE", - MapConnectAttributeName => "CONNECT", - _ => "", - }; + var httpMethod = HttpAttributeDefinitionsByName.TryGetValue(attributeName, out var definition) + ? definition.Verb + : ""; var pattern = (attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : "") ?? ""; @@ -1432,13 +1350,15 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl source.AppendLine("("); } + var mapMethodSuffix = GetMapMethodSuffix(requestHandler.HttpMethod); + source.Append(indent); source.Append("builder.Map"); - source.Append(requestHandler.HttpMethod is "Get" or "Post" or "Put" or "Delete" or "Patch" ? requestHandler.HttpMethod : "Methods"); + source.Append(mapMethodSuffix ?? "Methods"); source.Append('('); source.Append(StringLiteral(requestHandler.Pattern)); source.Append(", "); - if (requestHandler.HttpMethod is "OPTIONS" or "HEAD" or "TRACE" or "CONNECT" or "QUERY") + if (mapMethodSuffix is null) { source.Append("new[] { \""); source.Append(requestHandler.HttpMethod); @@ -1638,6 +1558,19 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl } } + private static string? GetMapMethodSuffix(string httpMethod) + { + return httpMethod switch + { + "GET" => "Get", + "POST" => "Post", + "PUT" => "Put", + "DELETE" => "Delete", + "PATCH" => "Patch", + _ => null, + }; + } + private static string GetBindingSourceAttribute(BindingSource source, string? key) { return source switch @@ -1900,6 +1833,8 @@ _ when char.IsControl(c) => "\\u" + ((int)c).ToString("x4", CultureInfo.Invarian }; } + private readonly record struct HttpAttributeDefinition(string Name, string FullyQualifiedName, string Hint, string Verb); + private readonly record struct RequestHandler( RequestHandlerClass Class, RequestHandlerMethod Method,