diff --git a/README.md b/README.md index 1f6aa08..31fc6f4 100644 --- a/README.md +++ b/README.md @@ -140,41 +140,41 @@ applied to all endpoints within the class. ## Attribute Reference -| Definition | Usage | Description | -| --- | --- | --- | -| `[Accepts(string contentType = "application/json", params string[] additionalContentTypes, RequestType = null, IsOptional = false)]` | Method | Declares the accepted request body CLR type, optional status, and list of content types for the handler. | -| `[Accepts(string contentType = "application/json", params string[] additionalContentTypes, IsOptional = false)]` | Method | Generic shortcut for specifying the request type and accepted content types for the handler. | -| `[AllowAnonymous]` | Class or Method | Allows the annotated endpoint or class to bypass authorization requirements. | -| `[Description(string description)]` | Method | Sets the OpenAPI description metadata for the generated endpoint. | -| `[DisableAntiforgery]` | Class or Method | Disables antiforgery protection for the annotated endpoint(s). | -| `[DisableRequestTimeout]` | Class or Method | Disables request timeout enforcement for the annotated endpoint(s). | -| `[DisableValidation]` | Class or Method | Disables automatic request validation (when supported) for the annotated endpoint(s). | -| `[DisplayName(string displayName)]` | Method | Overrides the endpoint display name used in diagnostics and metadata. | -| `[EndpointFilter(Type filterType)]` | Class or Method | Adds the specified endpoint filter type to the handler pipeline. | -| `[EndpointFilter]` | Class or Method | Generic form for registering an endpoint filter type on the handler pipeline. | -| `[ExcludeFromDescription]` | Class or Method | Hides the endpoint or class from generated API descriptions (e.g., OpenAPI). | -| `[MapConnect(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP CONNECT endpoint using the supplied route pattern and optional name. | -| `[MapDelete(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP DELETE endpoint using the supplied route pattern and optional name. | -| `[MapFallback(string pattern = "", Name = null)]` | Method | Maps the method as the fallback endpoint invoked when no other route matches. | -| `[MapGroup(string pattern, Name = null)]` | Class | Assigns a route group pattern and optional endpoint group name to every handler in the class. | -| `[MapGet(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP GET endpoint using the supplied route pattern and optional name. | -| `[MapHead(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP HEAD endpoint using the supplied route pattern and optional name. | -| `[MapOptions(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP OPTIONS endpoint using the supplied route pattern and optional name. | -| `[MapPatch(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP PATCH endpoint using the supplied route pattern and optional name. | -| `[MapPost(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP POST endpoint using the supplied route pattern and optional name. | -| `[MapPut(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP PUT endpoint using the supplied route pattern and optional name. | -| `[MapQuery(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP QUERY endpoint using the supplied route pattern and optional name. | -| `[MapTrace(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP TRACE endpoint using the supplied route pattern and optional name. | -| `[Order(int order)]` | Method | Controls the order in which endpoint conventions are applied to the handler. | -| `[ProducesProblem(int statusCode = StatusCodes.Status500InternalServerError, string? contentType = null, params string[] additionalContentTypes)]` | Method | Declares that the endpoint emits a problem details payload for the given status code and content types. | -| `[ProducesResponse(int statusCode = StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes, ResponseType = null)]` | Method | Declares response metadata for the handler, including status code, optional CLR type, and content types. | -| `[ProducesResponse(int statusCode = StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes)]` | Method | Generic shorthand for declaring the CLR response type along with status code and content types. | -| `[ProducesValidationProblem(int statusCode = StatusCodes.Status400BadRequest, string? contentType = null, params string[] additionalContentTypes)]` | Method | Declares that the endpoint returns validation problem details for the specified status code and content types. | -| `[RequestTimeout(string? policyName = null)]` | Class or Method | Applies the default or a named request-timeout policy to the handler(s). | -| `[RequireAuthorization(params string[] policies)]` | Class or Method | Enforces authorization on the handler(s), optionally scoping access to specific policies. | -| `[RequireCors(string? policyName = null)]` | Class or Method | Requires the default or a named CORS policy for the annotated handler(s). | -| `[RequireHost(params string[] hosts)]` | Class or Method | Restricts the handler(s) to the specified allowed hostnames. | -| `[RequireRateLimiting(string policyName)]` | Class or Method | Enforces the named rate-limiting policy on the annotated handler(s). | -| `[ShortCircuit]` | Class or Method | Marks the handler(s) to short-circuit the request pipeline when invoked. | -| `[Summary(string summary)]` | Class or Method | Sets the summary metadata applied to the generated endpoint(s). | -| `[Tags(params string[] tags)]` | Class or Method | Assigns OpenAPI tags to the annotated handler(s) for grouping in API docs. | +| Definition | Usage | Description | +|---------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------|----------------------------------------------------------------------------------------------------------------| +| `[Accepts(string contentType = "application/json", params string[] additionalContentTypes, RequestType = null, IsOptional = false)]` | Method | Declares the accepted request body CLR type, optional status, and list of content types for the handler. | +| `[Accepts(string contentType = "application/json", params string[] additionalContentTypes, IsOptional = false)]` | Method | Generic shortcut for specifying the request type and accepted content types for the handler. | +| `[AllowAnonymous]` | Class or Method | Allows the annotated endpoint or class to bypass authorization requirements. | +| `[Description(string description)]` | Method | Sets the OpenAPI description metadata for the generated endpoint. | +| `[DisableAntiforgery]` | Class or Method | Disables antiforgery protection for the annotated endpoint(s). | +| `[DisableRequestTimeout]` | Class or Method | Disables request timeout enforcement for the annotated endpoint(s). | +| `[DisableValidation]` | Class or Method | Disables automatic request validation (when supported) for the annotated endpoint(s). | +| `[DisplayName(string displayName)]` | Method | Overrides the endpoint display name used in diagnostics and metadata. | +| `[EndpointFilter(Type filterType)]` | Class or Method | Adds the specified endpoint filter type to the handler pipeline. | +| `[EndpointFilter]` | Class or Method | Generic form for registering an endpoint filter type on the handler pipeline. | +| `[ExcludeFromDescription]` | Class or Method | Hides the endpoint or class from generated API descriptions (e.g., OpenAPI). | +| `[MapConnect(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP CONNECT endpoint using the supplied route pattern and optional name. | +| `[MapDelete(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP DELETE endpoint using the supplied route pattern and optional name. | +| `[MapFallback(string pattern = "", Name = null)]` | Method | Maps the method as the fallback endpoint invoked when no other route matches. | +| `[MapGroup(string pattern, Name = null)]` | Class | Assigns a route group pattern and optional endpoint group name to every handler in the class. | +| `[MapGet(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP GET endpoint using the supplied route pattern and optional name. | +| `[MapHead(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP HEAD endpoint using the supplied route pattern and optional name. | +| `[MapOptions(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP OPTIONS endpoint using the supplied route pattern and optional name. | +| `[MapPatch(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP PATCH endpoint using the supplied route pattern and optional name. | +| `[MapPost(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP POST endpoint using the supplied route pattern and optional name. | +| `[MapPut(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP PUT endpoint using the supplied route pattern and optional name. | +| `[MapQuery(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP QUERY endpoint using the supplied route pattern and optional name. | +| `[MapTrace(string pattern = "", Name = null)]` | Method | Marks the method as an HTTP TRACE endpoint using the supplied route pattern and optional name. | +| `[Order(int order)]` | Method | Controls the order in which endpoint conventions are applied to the handler. | +| `[ProducesProblem(int statusCode = StatusCodes.Status500InternalServerError, string? contentType = null, params string[] additionalContentTypes)]` | Method | Declares that the endpoint emits a problem details payload for the given status code and content types. | +| `[ProducesResponse(int statusCode = StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes, ResponseType = null)]` | Method | Declares response metadata for the handler, including status code, optional CLR type, and content types. | +| `[ProducesResponse(int statusCode = StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes)]` | Method | Generic shorthand for declaring the CLR response type along with status code and content types. | +| `[ProducesValidationProblem(int statusCode = StatusCodes.Status400BadRequest, string? contentType = null, params string[] additionalContentTypes)]` | Method | Declares that the endpoint returns validation problem details for the specified status code and content types. | +| `[RequestTimeout(string? policyName = null)]` | Class or Method | Applies the default or a named request-timeout policy to the handler(s). | +| `[RequireAuthorization(params string[] policies)]` | Class or Method | Enforces authorization on the handler(s), optionally scoping access to specific policies. | +| `[RequireCors(string? policyName = null)]` | Class or Method | Requires the default or a named CORS policy for the annotated handler(s). | +| `[RequireHost(params string[] hosts)]` | Class or Method | Restricts the handler(s) to the specified allowed hostnames. | +| `[RequireRateLimiting(string policyName)]` | Class or Method | Enforces the named rate-limiting policy on the annotated handler(s). | +| `[ShortCircuit]` | Class or Method | Marks the handler(s) to short-circuit the request pipeline when invoked. | +| `[Summary(string summary)]` | Class or Method | Sets the summary metadata applied to the generated endpoint(s). | +| `[Tags(params string[] tags)]` | Class or Method | Assigns OpenAPI tags to the annotated handler(s) for grouping in API docs. | diff --git a/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs new file mode 100644 index 0000000..2da92c7 --- /dev/null +++ b/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs @@ -0,0 +1,105 @@ +using System.Text; +using GeneratedEndpoints.Common; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; +using static GeneratedEndpoints.Common.Constants; + +namespace GeneratedEndpoints; + +// ReSharper disable ForCanBeConvertedToForeach +// ReSharper disable LoopCanBeConvertedToQuery +// Do not refactor, use for loop to avoid allocations. + +internal static class AddEndpointHandlersGenerator +{ + public static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) + { + context.CancellationToken.ThrowIfCancellationRequested(); + + var nonStaticClassNames = GetDistinctNonStaticClassNames(requestHandlers); + var source = GetAddEndpointHandlersStringBuilder(nonStaticClassNames); + source.AppendLine(FileHeader); + + source.AppendLine(); + + source.AppendLine("using Microsoft.Extensions.DependencyInjection;"); + source.AppendLine("using Microsoft.Extensions.DependencyInjection.Extensions;"); + source.AppendLine(); + + source.Append("namespace "); + source.Append(RoutingNamespace); + source.AppendLine(";"); + + source.AppendLine(); + + source.Append("internal static class "); + source.Append(AddEndpointHandlersClassName); + source.AppendLine(); + + source.AppendLine("{"); + + source.Append(" internal static void "); + source.Append(AddEndpointHandlersMethodName); + source.AppendLine("(this IServiceCollection services)"); + + source.AppendLine(" {"); + + foreach (var className in nonStaticClassNames) + { + source.Append(" services.TryAddScoped<"); + source.Append(className); + source.Append(">();"); + source.AppendLine(); + } + + source.AppendLine(""" + } + } + """ + ); + + var sourceText = StringBuilderPool.ToStringAndReturn(source); + context.AddSource(AddEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); + } + + private static List GetDistinctNonStaticClassNames(EquatableImmutableArray requestHandlers) + { + var classNames = new List(); + if (requestHandlers.Count == 0) + return classNames; + + var seen = new HashSet(StringComparer.Ordinal); + for (var index = 0; index < requestHandlers.Count; index++) + { + var requestHandler = requestHandlers[index]; + if (requestHandler.Class.IsStatic) + continue; + + var className = requestHandler.Class.Name; + if (seen.Add(className)) + classNames.Add(className); + } + + return classNames; + } + + private static StringBuilder GetAddEndpointHandlersStringBuilder(List nonStaticClassNames) + { + var estimate = 512L; + for (var index = 0; index < nonStaticClassNames.Count; index++) + { + var className = nonStaticClassNames[index]; + estimate += 36 + className.Length; + } + + estimate += Math.Max(256, nonStaticClassNames.Count * 12); + estimate = (long)(estimate * 1.10); + + if (estimate < 512) + estimate = 512; + else if (estimate > int.MaxValue) + estimate = int.MaxValue; + + return StringBuilderPool.Get((int)estimate); + } +} diff --git a/src/GeneratedEndpoints/Common/AcceptsMetadata.cs b/src/GeneratedEndpoints/Common/AcceptsMetadata.cs new file mode 100644 index 0000000..3656580 --- /dev/null +++ b/src/GeneratedEndpoints/Common/AcceptsMetadata.cs @@ -0,0 +1,8 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct AcceptsMetadata( + string RequestType, + string ContentType, + EquatableImmutableArray? AdditionalContentTypes, + bool IsOptional +); diff --git a/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs new file mode 100644 index 0000000..1a10e01 --- /dev/null +++ b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs @@ -0,0 +1,91 @@ +using Microsoft.CodeAnalysis; + +namespace GeneratedEndpoints.Common; + +internal static class AttributeDataExtensions +{ + public static string? GetConstructorStringValue(this AttributeData attribute, int position = 0) + { + if (attribute.ConstructorArguments.Length > position) + return (attribute.ConstructorArguments[position].Value as string).NormalizeOptionalString(); + + return null; + } + + public static EquatableImmutableArray? GetConstructorStringArray(this AttributeData attribute, int position = 0) + { + if (attribute.ConstructorArguments.Length <= position) + return null; + + var arg = attribute.ConstructorArguments[position]; + if (arg.Kind == TypedConstantKind.Array) + { + if (arg.Values.Length == 0) + return null; + + List? normalized = null; + foreach (var value in arg.Values) + { + if (value.Value is not string stringValue) + continue; + + var trimmed = stringValue.NormalizeOptionalString(); + if (trimmed is not { Length: > 0 }) + continue; + + normalized ??= new List(arg.Values.Length); + normalized.Add(trimmed); + } + + if (normalized is { Count: > 0 }) + return normalized.ToEquatableImmutableArray(); + } + else if (arg.Value is string singleHost && !string.IsNullOrWhiteSpace(singleHost)) + { + return new[] { singleHost.Trim() }.ToEquatableImmutableArray(); + } + + return null; + } + + public static int? GetConstructorIntValue(this AttributeData attribute, int position = 0) + { + if (attribute.ConstructorArguments.Length > position && attribute.ConstructorArguments[position].Value is int value) + return value; + + return null; + } + + public static ITypeSymbol? GetNamedTypeSymbol(this AttributeData attribute, string namedParameter) + { + foreach (var namedArg in attribute.NamedArguments) + { + if (namedArg.Key == namedParameter && namedArg.Value.Value is ITypeSymbol typeSymbol) + return typeSymbol; + } + + return null; + } + + public static bool GetNamedBoolValue(this AttributeData attribute, string namedParameter, bool defaultValue = false) + { + foreach (var namedArg in attribute.NamedArguments) + { + if (namedArg.Key == namedParameter && namedArg.Value.Value is bool boolValue) + return boolValue; + } + + return defaultValue; + } + + public static string? GetNamedStringValue(this AttributeData attribute, string namedParameter) + { + foreach (var namedArg in attribute.NamedArguments) + { + if (namedArg.Key == namedParameter && namedArg.Value.Value is string stringValue) + return stringValue.NormalizeOptionalString(); + } + + return null; + } +} diff --git a/src/GeneratedEndpoints/Common/AttributeSymbolMatcher.cs b/src/GeneratedEndpoints/Common/AttributeSymbolMatcher.cs new file mode 100644 index 0000000..fc5dc15 --- /dev/null +++ b/src/GeneratedEndpoints/Common/AttributeSymbolMatcher.cs @@ -0,0 +1,25 @@ +using Microsoft.CodeAnalysis; + +namespace GeneratedEndpoints.Common; + +internal static class AttributeSymbolMatcher +{ + public static bool IsAttribute(INamedTypeSymbol attributeClass, string attributeName, string[] namespaceParts) + { + var definition = attributeClass.OriginalDefinition; + return definition.Name == attributeName && IsInNamespace(definition.ContainingNamespace, namespaceParts); + } + + public static bool IsInNamespace(INamespaceSymbol? namespaceSymbol, string[] namespaceParts) + { + for (var i = namespaceParts.Length - 1; i >= 0; i--) + { + if (namespaceSymbol is null || namespaceSymbol.Name != namespaceParts[i]) + return false; + + namespaceSymbol = namespaceSymbol.ContainingNamespace; + } + + return namespaceSymbol is null || namespaceSymbol.IsGlobalNamespace; + } +} diff --git a/src/GeneratedEndpoints/Common/BindingSource.cs b/src/GeneratedEndpoints/Common/BindingSource.cs new file mode 100644 index 0000000..4db1144 --- /dev/null +++ b/src/GeneratedEndpoints/Common/BindingSource.cs @@ -0,0 +1,14 @@ +namespace GeneratedEndpoints.Common; + +internal enum BindingSource +{ + None = 0, + FromRoute = 1, + FromQuery = 2, + FromHeader = 3, + FromBody = 4, + FromForm = 5, + FromServices = 6, + FromKeyedServices = 7, + AsParameters = 8, +} diff --git a/src/GeneratedEndpoints/Common/ConfigureMethodDetails.cs b/src/GeneratedEndpoints/Common/ConfigureMethodDetails.cs new file mode 100644 index 0000000..0a9d741 --- /dev/null +++ b/src/GeneratedEndpoints/Common/ConfigureMethodDetails.cs @@ -0,0 +1,3 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct ConfigureMethodDetails(bool HasConfigureMethod, bool ConfigureMethodAcceptsServiceProvider); diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.Constants.cs b/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs similarity index 52% rename from src/GeneratedEndpoints/MinimalApiGenerator.Constants.cs rename to src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs index 0f51419..6b6d8e5 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.Constants.cs +++ b/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs @@ -1,136 +1,26 @@ using System.Collections.Immutable; -using System.Runtime.CompilerServices; using System.Text; -using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.Text; -namespace GeneratedEndpoints; +namespace GeneratedEndpoints.Common; -public sealed partial class MinimalApiGenerator +internal static partial class Constants { - private const string BaseNamespace = "Microsoft.AspNetCore.Generated"; - private const string AttributesNamespace = $"{BaseNamespace}.Attributes"; - - private const string FallbackHttpMethod = "__FALLBACK__"; - - private const string NameAttributeNamedParameter = "Name"; - private const string ResponseTypeAttributeNamedParameter = "ResponseType"; - private const string RequestTypeAttributeNamedParameter = "RequestType"; - private const string IsOptionalAttributeNamedParameter = "IsOptional"; - private const string PolicyNameAttributeNamedParameter = "PolicyName"; - - private const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute"; - private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}"; - private const string RequireAuthorizationAttributeHint = $"{RequireAuthorizationAttributeFullyQualifiedName}.gs.cs"; - - private const string RequireCorsAttributeName = "RequireCorsAttribute"; - private const string RequireCorsAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireCorsAttributeName}"; - private const string RequireCorsAttributeHint = $"{RequireCorsAttributeFullyQualifiedName}.gs.cs"; - - private const string RequireRateLimitingAttributeName = "RequireRateLimitingAttribute"; - private const string RequireRateLimitingAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireRateLimitingAttributeName}"; - private const string RequireRateLimitingAttributeHint = $"{RequireRateLimitingAttributeFullyQualifiedName}.gs.cs"; - - private const string RequireHostAttributeName = "RequireHostAttribute"; - private const string RequireHostAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireHostAttributeName}"; - private const string RequireHostAttributeHint = $"{RequireHostAttributeFullyQualifiedName}.gs.cs"; - - private const string DisableAntiforgeryAttributeName = "DisableAntiforgeryAttribute"; - private const string DisableAntiforgeryAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableAntiforgeryAttributeName}"; - private const string DisableAntiforgeryAttributeHint = $"{DisableAntiforgeryAttributeFullyQualifiedName}.gs.cs"; - - private const string ShortCircuitAttributeName = "ShortCircuitAttribute"; - private const string ShortCircuitAttributeFullyQualifiedName = $"{AttributesNamespace}.{ShortCircuitAttributeName}"; - private const string ShortCircuitAttributeHint = $"{ShortCircuitAttributeFullyQualifiedName}.gs.cs"; - - private const string DisableRequestTimeoutAttributeName = "DisableRequestTimeoutAttribute"; - private const string DisableRequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableRequestTimeoutAttributeName}"; - private const string DisableRequestTimeoutAttributeHint = $"{DisableRequestTimeoutAttributeFullyQualifiedName}.gs.cs"; - - private const string DisableValidationAttributeName = "DisableValidationAttribute"; - private const string DisableValidationAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableValidationAttributeName}"; - private const string DisableValidationAttributeHint = $"{DisableValidationAttributeFullyQualifiedName}.gs.cs"; - - private const string RequestTimeoutAttributeName = "RequestTimeoutAttribute"; - private const string RequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequestTimeoutAttributeName}"; - private const string RequestTimeoutAttributeHint = $"{RequestTimeoutAttributeFullyQualifiedName}.gs.cs"; - - private const string OrderAttributeName = "OrderAttribute"; - private const string OrderAttributeFullyQualifiedName = $"{AttributesNamespace}.{OrderAttributeName}"; - private const string OrderAttributeHint = $"{OrderAttributeFullyQualifiedName}.gs.cs"; - - private const string MapGroupAttributeName = "MapGroupAttribute"; - private const string MapGroupAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapGroupAttributeName}"; - private const string MapGroupAttributeHint = $"{MapGroupAttributeFullyQualifiedName}.gs.cs"; - - private const string SummaryAttributeName = "SummaryAttribute"; - private const string SummaryAttributeFullyQualifiedName = $"{AttributesNamespace}.{SummaryAttributeName}"; - private const string SummaryAttributeHint = $"{SummaryAttributeFullyQualifiedName}.gs.cs"; - - private const string AllowAnonymousAttributeName = "AllowAnonymousAttribute"; - - private const string EndpointFilterAttributeName = "EndpointFilterAttribute"; - private const string EndpointFilterAttributeFullyQualifiedName = $"{AttributesNamespace}.{EndpointFilterAttributeName}"; - private const string EndpointFilterAttributeHint = $"{EndpointFilterAttributeFullyQualifiedName}.gs.cs"; - - private const string AcceptsAttributeName = "AcceptsAttribute"; - private const string AcceptsAttributeFullyQualifiedName = $"{AttributesNamespace}.{AcceptsAttributeName}"; - private const string AcceptsAttributeHint = $"{AcceptsAttributeFullyQualifiedName}.gs.cs"; - - private const string ProducesResponseAttributeName = "ProducesResponseAttribute"; - private const string ProducesResponseAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesResponseAttributeName}"; - private const string ProducesResponseAttributeHint = $"{ProducesResponseAttributeFullyQualifiedName}.gs.cs"; - - private const string ProducesProblemAttributeName = "ProducesProblemAttribute"; - private const string ProducesProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesProblemAttributeName}"; - private const string ProducesProblemAttributeHint = $"{ProducesProblemAttributeFullyQualifiedName}.gs.cs"; - - private const string ProducesValidationProblemAttributeName = "ProducesValidationProblemAttribute"; - - private const string ProducesValidationProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesValidationProblemAttributeName}"; - - private const string ProducesValidationProblemAttributeHint = $"{ProducesValidationProblemAttributeFullyQualifiedName}.gs.cs"; - - private const string RoutingNamespace = $"{BaseNamespace}.Routing"; - - private const string AddEndpointHandlersClassName = "EndpointServicesExtensions"; - private const string AddEndpointHandlersMethodName = "AddEndpointHandlers"; - private const string AddEndpointHandlersMethodHint = $"{RoutingNamespace}.{AddEndpointHandlersMethodName}.g.cs"; - - private const string UseEndpointHandlersClassName = "EndpointRouteBuilderExtensions"; - private const string UseEndpointHandlersMethodName = "MapEndpointHandlers"; - private const string UseEndpointHandlersMethodHint = $"{RoutingNamespace}.{UseEndpointHandlersMethodName}.g.cs"; - - private const string ConfigureMethodName = "Configure"; - private const string AsyncSuffix = "Async"; - private const string GlobalPrefix = "global::"; - private static readonly string[] AttributesNamespaceParts = AttributesNamespace.Split('.'); - private static readonly string[] AspNetCoreHttpNamespaceParts = ["Microsoft", "AspNetCore", "Http"]; - private static readonly string[] AspNetCoreMvcNamespaceParts = ["Microsoft", "AspNetCore", "Mvc"]; - private static readonly string[] AspNetCoreAuthorizationNamespaceParts = ["Microsoft", "AspNetCore", "Authorization"]; - private static readonly string[] AspNetCoreRoutingNamespaceParts = ["Microsoft", "AspNetCore", "Routing"]; - private static readonly string[] ExtensionsDependencyInjectionNamespaceParts = - ["Microsoft", "Extensions", "DependencyInjection"]; - private static readonly string[] ComponentModelNamespaceParts = ["System", "ComponentModel"]; - private static readonly ConditionalWeakTable CompilationTypeCaches = new(); - private static readonly ConditionalWeakTable RequestHandlerClassCache = new(); - private static readonly ConditionalWeakTable GeneratedAttributeKindCache = new(); - - private static readonly string FileHeader = $""" - //----------------------------------------------------------------------------- - // - // This code was generated by {nameof(MinimalApiGenerator)} which can be found - // in the {typeof(MinimalApiGenerator).Namespace} namespace. - // - // Changes to this file may cause incorrect behavior - // and will be lost if the code is regenerated. - // - //----------------------------------------------------------------------------- - - #nullable enable - """; - - private static readonly ImmutableArray HttpAttributeDefinitions = + internal static readonly string FileHeader = $""" + //----------------------------------------------------------------------------- + // + // This code was generated by {nameof(MinimalApiGenerator)} which can be found + // in the {typeof(MinimalApiGenerator).Namespace} namespace. + // + // Changes to this file may cause incorrect behavior + // and will be lost if the code is regenerated. + // + //----------------------------------------------------------------------------- + + #nullable enable + """; + + internal static readonly ImmutableArray HttpAttributeDefinitions = [ CreateHttpAttributeDefinition("MapGetAttribute", "GET"), CreateHttpAttributeDefinition("MapPostAttribute", "POST"), @@ -145,513 +35,556 @@ public sealed partial class MinimalApiGenerator CreateHttpAttributeDefinition("MapFallbackAttribute", FallbackHttpMethod, true), ]; - private static readonly ImmutableDictionary HttpAttributeDefinitionsByName = + internal static readonly ImmutableDictionary HttpAttributeDefinitionsByName = HttpAttributeDefinitions.ToImmutableDictionary(static definition => definition.Name); - private static readonly SourceText RequireAuthorizationAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText RequireAuthorizationAttributeSourceText = SourceText.From($$""" + {{FileHeader}} + + namespace {{AttributesNamespace}}; + + /// + /// Specifies that authorization is required for the annotated endpoint or class. + /// Optionally restricts access to the specified authorization policies. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{RequireAuthorizationAttributeName}} : global::System.Attribute + { + /// + /// Gets the policy names that the endpoint or class requires. + /// + public string[] PolicyNames { get; } + + /// + /// Marks the endpoint or class as requiring authorization. + /// + public {{RequireAuthorizationAttributeName}}() + { + PolicyNames = []; + } + + /// + /// Marks the endpoint or class as requiring authorization with one or more policies. + /// + public {{RequireAuthorizationAttributeName}}(params string[] policyNames) + { + PolicyNames = policyNames ?? []; + } + } + """, Encoding.UTF8 + ); + + internal static readonly SourceText RequireCorsAttributeSourceText = SourceText.From($$""" + {{FileHeader}} + + namespace {{AttributesNamespace}}; + + /// + /// Specifies that the annotated endpoint requires a configured CORS policy. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{RequireCorsAttributeName}} : global::System.Attribute + { + /// + /// Gets the optional CORS policy name. + /// + public string? PolicyName { get; } + + /// + /// Marks the endpoint or class as requiring the default CORS policy. + /// + public {{RequireCorsAttributeName}}() + { + } + + /// + /// Marks the endpoint or class as requiring the specified named CORS policy. + /// + public {{RequireCorsAttributeName}}(string policyName) + { + PolicyName = policyName; + } + } + """, Encoding.UTF8 + ); + + internal static readonly SourceText RequireRateLimitingAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; /// - /// Specifies that authorization is required for the annotated endpoint or class. - /// Optionally restricts access to the specified authorization policies. + /// Specifies that the annotated endpoint requires the provided rate limiting policy. /// [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{RequireAuthorizationAttributeName}} : global::System.Attribute + internal sealed class {{RequireRateLimitingAttributeName}} : global::System.Attribute { /// - /// Gets the policy names that the endpoint or class requires. - /// - public string[] PolicyNames { get; } - - /// - /// Marks the endpoint or class as requiring authorization. + /// Initializes a new instance of the class. /// - public {{RequireAuthorizationAttributeName}}() + /// The rate limiting policy to apply. + public {{RequireRateLimitingAttributeName}}(string policyName) { - PolicyNames = []; + PolicyName = policyName; } /// - /// Marks the endpoint or class as requiring authorization with one or more policies. + /// Gets the rate limiting policy name. /// - public {{RequireAuthorizationAttributeName}}(params string[] policyNames) - { - PolicyNames = policyNames ?? []; - } + public string PolicyName { get; } } """, Encoding.UTF8 ); - private static readonly SourceText RequireCorsAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; + internal static readonly SourceText RequireHostAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - /// - /// Specifies that the annotated endpoint requires a configured CORS policy. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{RequireCorsAttributeName}} : global::System.Attribute - { - /// - /// Gets the optional CORS policy name. - /// - public string? PolicyName { get; } + namespace {{AttributesNamespace}}; - /// - /// Marks the endpoint or class as requiring the default CORS policy. - /// - public {{RequireCorsAttributeName}}() - { - } + /// + /// Specifies the allowed hosts for the annotated endpoint or class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{RequireHostAttributeName}} : global::System.Attribute + { + /// + /// Initializes a new instance of the class. + /// + /// The hosts that are allowed to access the endpoint. + public {{RequireHostAttributeName}}(params string[] hosts) + { + Hosts = hosts ?? []; + } - /// - /// Marks the endpoint or class as requiring the specified named CORS policy. - /// - public {{RequireCorsAttributeName}}(string policyName) - { - PolicyName = policyName; - } - } - """, Encoding.UTF8 + /// + /// Gets the allowed hosts. + /// + public string[] Hosts { get; } + } + """, Encoding.UTF8 ); - private static readonly SourceText RequireRateLimitingAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText DisableAntiforgeryAttributeSourceText = SourceText.From($$""" {{FileHeader}} namespace {{AttributesNamespace}}; /// - /// Specifies that the annotated endpoint requires the provided rate limiting policy. + /// Disables antiforgery protection for the annotated endpoint or class. /// [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{RequireRateLimitingAttributeName}} : global::System.Attribute + internal sealed class {{DisableAntiforgeryAttributeName}} : global::System.Attribute { - /// - /// Initializes a new instance of the class. - /// - /// The rate limiting policy to apply. - public {{RequireRateLimitingAttributeName}}(string policyName) - { - PolicyName = policyName; - } - - /// - /// Gets the rate limiting policy name. - /// - public string PolicyName { get; } } + """, Encoding.UTF8 ); - private static readonly SourceText RequireHostAttributeSourceText = SourceText.From($$""" - {{FileHeader}} + internal static readonly SourceText ShortCircuitAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - namespace {{AttributesNamespace}}; + namespace {{AttributesNamespace}}; - /// - /// Specifies the allowed hosts for the annotated endpoint or class. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{RequireHostAttributeName}} : global::System.Attribute - { - /// - /// Initializes a new instance of the class. - /// - /// The hosts that are allowed to access the endpoint. - public {{RequireHostAttributeName}}(params string[] hosts) - { - Hosts = hosts ?? []; - } + /// + /// Marks the annotated endpoint or class to short-circuit the request pipeline. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{ShortCircuitAttributeName}} : global::System.Attribute + { + } - /// - /// Gets the allowed hosts. - /// - public string[] Hosts { get; } - } - """, Encoding.UTF8 + """, Encoding.UTF8 ); - private static readonly SourceText DisableAntiforgeryAttributeSourceText = SourceText.From($$""" + internal static readonly SourceText DisableRequestTimeoutAttributeSourceText = SourceText.From($$""" + {{FileHeader}} + + namespace {{AttributesNamespace}}; + + /// + /// Disables the request timeout for the annotated endpoint or class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{DisableRequestTimeoutAttributeName}} : global::System.Attribute + { + } + + """, Encoding.UTF8 + ); + + internal static readonly SourceText DisableValidationAttributeSourceText = SourceText.From($$""" + #if NET10_0_OR_GREATER {{FileHeader}} namespace {{AttributesNamespace}}; /// - /// Disables antiforgery protection for the annotated endpoint or class. + /// Disables request validation for the annotated endpoint or class when targeting .NET 10 or later. /// [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{DisableAntiforgeryAttributeName}} : global::System.Attribute + internal sealed class {{DisableValidationAttributeName}} : global::System.Attribute { } + #endif """, Encoding.UTF8 ); - private static readonly SourceText ShortCircuitAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; + internal static readonly SourceText RequestTimeoutAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - /// - /// Marks the annotated endpoint or class to short-circuit the request pipeline. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{ShortCircuitAttributeName}} : global::System.Attribute - { - } + namespace {{AttributesNamespace}}; - """, Encoding.UTF8 - ); + /// + /// Applies the request timeout metadata to the annotated endpoint or class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{RequestTimeoutAttributeName}} : global::System.Attribute + { + /// + /// Gets the optional request timeout policy name. + /// + public string? PolicyName { get; init; } - private static readonly SourceText DisableRequestTimeoutAttributeSourceText = SourceText.From($$""" - {{FileHeader}} + /// + /// Applies the default request timeout behavior. + /// + public {{RequestTimeoutAttributeName}}() + { + } - namespace {{AttributesNamespace}}; + /// + /// Applies the specified request timeout policy. + /// + /// The request timeout policy name. + public {{RequestTimeoutAttributeName}}(string policyName) + { + PolicyName = policyName; + } + } - /// - /// Disables the request timeout for the annotated endpoint or class. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{DisableRequestTimeoutAttributeName}} : global::System.Attribute - { - } + """, Encoding.UTF8 + ); - """, Encoding.UTF8 + internal static readonly SourceText OrderAttributeSourceText = SourceText.From($$""" + {{FileHeader}} + + namespace {{AttributesNamespace}}; + + /// + /// Specifies the order for the annotated endpoint when building conventions. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{OrderAttributeName}} : global::System.Attribute + { + /// + /// Gets the order that will be applied to the endpoint. + /// + public int Order { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The order value to apply to the endpoint. + public {{OrderAttributeName}}(int order) + { + Order = order; + } + } + + """, Encoding.UTF8 ); - private static readonly SourceText DisableValidationAttributeSourceText = SourceText.From($$""" - #if NET10_0_OR_GREATER - {{FileHeader}} + internal static readonly SourceText MapGroupAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - namespace {{AttributesNamespace}}; + namespace {{AttributesNamespace}}; - /// - /// Disables request validation for the annotated endpoint or class when targeting .NET 10 or later. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{DisableValidationAttributeName}} : global::System.Attribute - { - } - #endif + /// + /// Specifies the route group for the annotated class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class, Inherited = false, AllowMultiple = false)] + internal sealed class {{MapGroupAttributeName}} : global::System.Attribute + { + /// + /// Gets the route group pattern. + /// + public string Pattern { get; } + + /// + /// Gets or sets the endpoint group name. + /// + public string? Name { get; init; } + + /// + /// Initializes a new instance of the class. + /// + /// The route group pattern to apply. + public {{MapGroupAttributeName}}(string pattern) + { + Pattern = pattern; + } + } - """, Encoding.UTF8 + """, Encoding.UTF8 ); - private static readonly SourceText RequestTimeoutAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; + internal static readonly SourceText SummaryAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - /// - /// Applies the request timeout metadata to the annotated endpoint or class. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{RequestTimeoutAttributeName}} : global::System.Attribute - { - /// - /// Gets the optional request timeout policy name. - /// - public string? PolicyName { get; init; } + namespace {{AttributesNamespace}}; - /// - /// Applies the default request timeout behavior. - /// - public {{RequestTimeoutAttributeName}}() - { - } + /// + /// Specifies the summary metadata for the annotated endpoint. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{SummaryAttributeName}} : global::System.Attribute + { + /// + /// Gets the summary value for the endpoint. + /// + public string Summary { get; } - /// - /// Applies the specified request timeout policy. - /// - /// The request timeout policy name. - public {{RequestTimeoutAttributeName}}(string policyName) - { - PolicyName = policyName; - } - } + /// + /// Initializes a new instance of the class. + /// + /// The summary to apply to the endpoint. + public {{SummaryAttributeName}}(string summary) + { + Summary = summary; + } + } - """, Encoding.UTF8 + """, Encoding.UTF8 ); - private static readonly SourceText OrderAttributeSourceText = SourceText.From($$""" - {{FileHeader}} + internal static readonly SourceText AcceptsAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - namespace {{AttributesNamespace}}; + namespace {{AttributesNamespace}}; - /// - /// Specifies the order for the annotated endpoint when building conventions. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{OrderAttributeName}} : global::System.Attribute - { - /// - /// Gets the order that will be applied to the endpoint. - /// - public int Order { get; } + /// + /// Specifies the request type and content types accepted by the annotated endpoint or class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{AcceptsAttributeName}} : global::System.Attribute + { + /// + /// Gets the request type accepted by the endpoint. + /// + public global::System.Type? RequestType { get; init; } - /// - /// Initializes a new instance of the class. - /// - /// The order value to apply to the endpoint. - public {{OrderAttributeName}}(int order) - { - Order = order; - } - } + /// + /// Gets a value indicating whether the request body is optional. + /// + public bool IsOptional { get; init; } - """, Encoding.UTF8 - ); + /// + /// Gets the primary content type accepted by the endpoint. + /// + public string ContentType { get; } - private static readonly SourceText MapGroupAttributeSourceText = SourceText.From($$""" - {{FileHeader}} + /// + /// Gets the additional content types accepted by the endpoint. + /// + public string[] AdditionalContentTypes { get; } - namespace {{AttributesNamespace}}; + /// + /// Initializes a new instance of the class. + /// + /// The primary content type accepted by the endpoint. + /// Additional content types accepted by the endpoint. + public {{AcceptsAttributeName}}(string contentType = "application/json", params string[] additionalContentTypes) + { + ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; + AdditionalContentTypes = additionalContentTypes ?? []; + } + } /// - /// Specifies the route group for the annotated class. + /// Specifies the request type using a generic argument and the content types accepted by the annotated endpoint or class. /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class, Inherited = false, AllowMultiple = false)] - internal sealed class {{MapGroupAttributeName}} : global::System.Attribute + /// The CLR type of the request body. + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{AcceptsAttributeName}} : global::System.Attribute { /// - /// Gets the route group pattern. + /// Gets the request type accepted by the endpoint. + /// + public global::System.Type RequestType => typeof(TRequest); + + /// + /// Gets a value indicating whether the request body is optional. /// - public string Pattern { get; } + public bool IsOptional { get; init; } /// - /// Gets or sets the endpoint group name. + /// Gets the primary content type accepted by the endpoint. /// - public string? Name { get; init; } + public string ContentType { get; } /// - /// Initializes a new instance of the class. + /// Gets the additional content types accepted by the endpoint. /// - /// The route group pattern to apply. - public {{MapGroupAttributeName}}(string pattern) + public string[] AdditionalContentTypes { get; } + + /// + /// Initializes a new instance of the generic Accepts attribute class. + /// + /// The primary content type accepted by the endpoint. + /// Additional content types accepted by the endpoint. + public {{AcceptsAttributeName}}(string contentType = "application/json", params string[] additionalContentTypes) { - Pattern = pattern; + ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; + AdditionalContentTypes = additionalContentTypes ?? []; } } """, Encoding.UTF8 ); - private static readonly SourceText SummaryAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; - - /// - /// Specifies the summary metadata for the annotated endpoint. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{SummaryAttributeName}} : global::System.Attribute - { - /// - /// Gets the summary value for the endpoint. - /// - public string Summary { get; } - - /// - /// Initializes a new instance of the class. - /// - /// The summary to apply to the endpoint. - public {{SummaryAttributeName}}(string summary) - { - Summary = summary; - } - } - - """, Encoding.UTF8 - ); + internal static readonly SourceText EndpointFilterAttributeSourceText = SourceText.From($$""" + {{FileHeader}} + + namespace {{AttributesNamespace}}; + + /// + /// Specifies an endpoint filter type to apply to the annotated endpoint or class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{EndpointFilterAttributeName}} : global::System.Attribute + { + /// + /// Gets the CLR type of the endpoint filter. + /// + public global::System.Type FilterType { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The CLR type of the endpoint filter. + public {{EndpointFilterAttributeName}}(global::System.Type filterType) + { + FilterType = filterType ?? throw new global::System.ArgumentNullException(nameof(filterType)); + } + } - private static readonly SourceText AcceptsAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; - - /// - /// Specifies the request type and content types accepted by the annotated endpoint or class. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{AcceptsAttributeName}} : global::System.Attribute - { - /// - /// Gets the request type accepted by the endpoint. - /// - public global::System.Type? RequestType { get; init; } - - /// - /// Gets a value indicating whether the request body is optional. - /// - public bool IsOptional { get; init; } - - /// - /// Gets the primary content type accepted by the endpoint. - /// - public string ContentType { get; } - - /// - /// Gets the additional content types accepted by the endpoint. - /// - public string[] AdditionalContentTypes { get; } - - /// - /// Initializes a new instance of the class. - /// - /// The primary content type accepted by the endpoint. - /// Additional content types accepted by the endpoint. - public {{AcceptsAttributeName}}(string contentType = "application/json", params string[] additionalContentTypes) - { - ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; - AdditionalContentTypes = additionalContentTypes ?? []; - } - } - - /// - /// Specifies the request type using a generic argument and the content types accepted by the annotated endpoint or class. - /// - /// The CLR type of the request body. - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{AcceptsAttributeName}} : global::System.Attribute - { - /// - /// Gets the request type accepted by the endpoint. - /// - public global::System.Type RequestType => typeof(TRequest); - - /// - /// Gets a value indicating whether the request body is optional. - /// - public bool IsOptional { get; init; } - - /// - /// Gets the primary content type accepted by the endpoint. - /// - public string ContentType { get; } - - /// - /// Gets the additional content types accepted by the endpoint. - /// - public string[] AdditionalContentTypes { get; } - - /// - /// Initializes a new instance of the generic Accepts attribute class. - /// - /// The primary content type accepted by the endpoint. - /// Additional content types accepted by the endpoint. - public {{AcceptsAttributeName}}(string contentType = "application/json", params string[] additionalContentTypes) - { - ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; - AdditionalContentTypes = additionalContentTypes ?? []; - } - } - - """, Encoding.UTF8 + /// + /// Specifies an endpoint filter type using a generic argument. + /// + /// The CLR type of the endpoint filter. + [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{EndpointFilterAttributeName}} : global::System.Attribute + { + /// + /// Gets the CLR type of the endpoint filter. + /// + public global::System.Type FilterType => typeof(TFilter); + } + + """, Encoding.UTF8 ); - private static readonly SourceText EndpointFilterAttributeSourceText = SourceText.From($$""" - {{FileHeader}} + internal static readonly SourceText ProducesResponseAttributeSourceText = SourceText.From($$""" + {{FileHeader}} - namespace {{AttributesNamespace}}; + namespace {{AttributesNamespace}}; - /// - /// Specifies an endpoint filter type to apply to the annotated endpoint or class. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{EndpointFilterAttributeName}} : global::System.Attribute - { - /// - /// Gets the CLR type of the endpoint filter. - /// - public global::System.Type FilterType { get; } + /// + /// Specifies a response type, status code, and content types produced by the annotated endpoint or class. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute + { + /// + /// Gets the response type produced by the endpoint. + /// + public global::System.Type? ResponseType { get; init; } - /// - /// Initializes a new instance of the class. - /// - /// The CLR type of the endpoint filter. - public {{EndpointFilterAttributeName}}(global::System.Type filterType) - { - FilterType = filterType ?? throw new global::System.ArgumentNullException(nameof(filterType)); - } - } - - /// - /// Specifies an endpoint filter type using a generic argument. - /// - /// The CLR type of the endpoint filter. - [global::System.AttributeUsage(global::System.AttributeTargets.Class | global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{EndpointFilterAttributeName}} : global::System.Attribute - { - /// - /// Gets the CLR type of the endpoint filter. - /// - public global::System.Type FilterType => typeof(TFilter); - } + /// + /// Gets the HTTP status code returned by the endpoint. + /// + public int StatusCode { get; } - """, Encoding.UTF8 - ); + /// + /// Gets the primary content type produced by the endpoint. + /// + public string? ContentType { get; } - private static readonly SourceText ProducesResponseAttributeSourceText = SourceText.From($$""" - {{FileHeader}} + /// + /// Gets the additional content types produced by the endpoint. + /// + public string[] AdditionalContentTypes { get; } - namespace {{AttributesNamespace}}; + /// + /// Initializes a new instance of the class. + /// + /// The HTTP status code returned by the endpoint. + /// The primary content type produced by the endpoint. + /// Additional content types produced by the endpoint. + public {{ProducesResponseAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) + { + StatusCode = statusCode; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes ?? []; + } + } - /// - /// Specifies a response type, status code, and content types produced by the annotated endpoint or class. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute - { - /// - /// Gets the response type produced by the endpoint. - /// - public global::System.Type? ResponseType { get; init; } + /// + /// Specifies a response type using a generic argument along with status code and content types produced by the annotated endpoint or class. + /// + /// The CLR type of the response body. + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute + { + /// + /// Gets the response type produced by the endpoint. + /// + public global::System.Type ResponseType => typeof(TResponse); - /// - /// Gets the HTTP status code returned by the endpoint. - /// - public int StatusCode { get; } + /// + /// Gets the HTTP status code returned by the endpoint. + /// + public int StatusCode { get; } - /// - /// Gets the primary content type produced by the endpoint. - /// - public string? ContentType { get; } + /// + /// Gets the primary content type produced by the endpoint. + /// + public string? ContentType { get; } - /// - /// Gets the additional content types produced by the endpoint. - /// - public string[] AdditionalContentTypes { get; } + /// + /// Gets the additional content types produced by the endpoint. + /// + public string[] AdditionalContentTypes { get; } - /// - /// Initializes a new instance of the class. - /// - /// The HTTP status code returned by the endpoint. - /// The primary content type produced by the endpoint. - /// Additional content types produced by the endpoint. - public {{ProducesResponseAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) - { - StatusCode = statusCode; - ContentType = contentType; - AdditionalContentTypes = additionalContentTypes ?? []; - } - } + /// + /// Initializes a new instance of the generic Produces attribute class. + /// + /// The HTTP status code returned by the endpoint. + /// The primary content type produced by the endpoint. + /// Additional content types produced by the endpoint. + public {{ProducesResponseAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) + { + StatusCode = statusCode; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes ?? []; + } + } + + """, Encoding.UTF8 + ); + + internal static readonly SourceText ProducesProblemAttributeSourceText = SourceText.From($$""" + {{FileHeader}} + + namespace {{AttributesNamespace}}; /// - /// Specifies a response type using a generic argument along with status code and content types produced by the annotated endpoint or class. + /// Specifies that the endpoint produces a problem details payload. /// - /// The CLR type of the response body. [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute + internal sealed class {{ProducesProblemAttributeName}} : global::System.Attribute { - /// - /// Gets the response type produced by the endpoint. - /// - public global::System.Type ResponseType => typeof(TResponse); - /// /// Gets the HTTP status code returned by the endpoint. /// @@ -668,12 +601,12 @@ internal sealed class {{ProducesResponseAttributeName}} : global::Sys public string[] AdditionalContentTypes { get; } /// - /// Initializes a new instance of the generic Produces attribute class. + /// Initializes a new instance of the class. /// /// The HTTP status code returned by the endpoint. /// The primary content type produced by the endpoint. /// Additional content types produced by the endpoint. - public {{ProducesResponseAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) + public {{ProducesProblemAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status500InternalServerError, string? contentType = null, params string[] additionalContentTypes) { StatusCode = statusCode; ContentType = contentType; @@ -684,90 +617,90 @@ internal sealed class {{ProducesResponseAttributeName}} : global::Sys """, Encoding.UTF8 ); - private static readonly SourceText ProducesProblemAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; - - /// - /// Specifies that the endpoint produces a problem details payload. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{ProducesProblemAttributeName}} : global::System.Attribute - { - /// - /// Gets the HTTP status code returned by the endpoint. - /// - public int StatusCode { get; } - - /// - /// Gets the primary content type produced by the endpoint. - /// - public string? ContentType { get; } - - /// - /// Gets the additional content types produced by the endpoint. - /// - public string[] AdditionalContentTypes { get; } - - /// - /// Initializes a new instance of the class. - /// - /// The HTTP status code returned by the endpoint. - /// The primary content type produced by the endpoint. - /// Additional content types produced by the endpoint. - public {{ProducesProblemAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status500InternalServerError, string? contentType = null, params string[] additionalContentTypes) - { - StatusCode = statusCode; - ContentType = contentType; - AdditionalContentTypes = additionalContentTypes ?? []; - } - } - - """, Encoding.UTF8 - ); - - private static readonly SourceText ProducesValidationProblemAttributeSourceText = SourceText.From($$""" - {{FileHeader}} - - namespace {{AttributesNamespace}}; - - /// - /// Specifies that the endpoint produces a validation problem details payload. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] - internal sealed class {{ProducesValidationProblemAttributeName}} : global::System.Attribute - { - /// - /// Gets the HTTP status code returned by the endpoint. - /// - public int StatusCode { get; } - - /// - /// Gets the primary content type produced by the endpoint. - /// - public string? ContentType { get; } - - /// - /// Gets the additional content types produced by the endpoint. - /// - public string[] AdditionalContentTypes { get; } - - /// - /// Initializes a new instance of the class. - /// - /// The HTTP status code returned by the endpoint. - /// The primary content type produced by the endpoint. - /// Additional content types produced by the endpoint. - public {{ProducesValidationProblemAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status400BadRequest, string? contentType = null, params string[] additionalContentTypes) - { - StatusCode = statusCode; - ContentType = contentType; - AdditionalContentTypes = additionalContentTypes ?? []; - } - } - - """, Encoding.UTF8 + internal static readonly SourceText ProducesValidationProblemAttributeSourceText = SourceText.From($$""" + {{FileHeader}} + + namespace {{AttributesNamespace}}; + + /// + /// Specifies that the endpoint produces a validation problem details payload. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] + internal sealed class {{ProducesValidationProblemAttributeName}} : global::System.Attribute + { + /// + /// Gets the HTTP status code returned by the endpoint. + /// + public int StatusCode { get; } + + /// + /// Gets the primary content type produced by the endpoint. + /// + public string? ContentType { get; } + + /// + /// Gets the additional content types produced by the endpoint. + /// + public string[] AdditionalContentTypes { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The HTTP status code returned by the endpoint. + /// The primary content type produced by the endpoint. + /// Additional content types produced by the endpoint. + public {{ProducesValidationProblemAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status400BadRequest, string? contentType = null, params string[] additionalContentTypes) + { + StatusCode = statusCode; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes ?? []; + } + } + + """, Encoding.UTF8 ); + private static HttpAttributeDefinition CreateHttpAttributeDefinition(string attributeName, string verb, bool allowOptionalPattern = false) + { + var fullyQualifiedName = $"{AttributesNamespace}.{attributeName}"; + var hint = $"{fullyQualifiedName}.gs.cs"; + var summaryVerb = verb == FallbackHttpMethod ? "fallback" : verb; + var source = GenerateHttpAttributeSource(AttributesNamespace, attributeName, summaryVerb, allowOptionalPattern); + return new HttpAttributeDefinition(attributeName, fullyQualifiedName, hint, verb, SourceText.From(source, Encoding.UTF8)); + } + + private static string GenerateHttpAttributeSource(string attributesNamespace, string attributeName, string summaryVerb, bool allowOptionalPattern = false) + { + return $$""" + {{FileHeader}} + + namespace {{attributesNamespace}}; + + /// + /// Identifies a method as an HTTP {{summaryVerb}} minimal API endpoint with the specified route pattern. + /// + [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] + internal sealed class {{attributeName}} : global::System.Attribute + { + /// + /// Gets the route pattern for the endpoint. + /// + public string{{(allowOptionalPattern ? "?" : "")}} Pattern { get; } + + /// + /// Gets or sets the endpoint name. + /// + public string? Name { get; init; } + + /// + /// Initializes a new instance of the class. + /// + /// The route pattern for the endpoint. + public {{attributeName}}([global::System.Diagnostics.CodeAnalysis.StringSyntax("Route")] string{{(allowOptionalPattern ? "?" : "")}} pattern{{(allowOptionalPattern ? " = null" : "")}}) + { + Pattern = pattern; + } + } + """; + } } diff --git a/src/GeneratedEndpoints/Common/Constants.cs b/src/GeneratedEndpoints/Common/Constants.cs new file mode 100644 index 0000000..8ef77d1 --- /dev/null +++ b/src/GeneratedEndpoints/Common/Constants.cs @@ -0,0 +1,115 @@ +using System.ComponentModel; + +namespace GeneratedEndpoints.Common; + +internal static partial class Constants +{ + internal const string FallbackHttpMethod = "__FALLBACK__"; + + internal const string NameAttributeNamedParameter = "Name"; + internal const string ResponseTypeAttributeNamedParameter = "ResponseType"; + internal const string RequestTypeAttributeNamedParameter = "RequestType"; + internal const string IsOptionalAttributeNamedParameter = "IsOptional"; + + internal const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute"; + internal const string RequireAuthorizationAttributeHint = $"{RequireAuthorizationAttributeFullyQualifiedName}.gs.cs"; + + internal const string RequireCorsAttributeName = "RequireCorsAttribute"; + internal const string RequireCorsAttributeHint = $"{RequireCorsAttributeFullyQualifiedName}.gs.cs"; + + internal const string RequireRateLimitingAttributeName = "RequireRateLimitingAttribute"; + internal const string RequireRateLimitingAttributeHint = $"{RequireRateLimitingAttributeFullyQualifiedName}.gs.cs"; + + internal const string RequireHostAttributeName = "RequireHostAttribute"; + internal const string RequireHostAttributeHint = $"{RequireHostAttributeFullyQualifiedName}.gs.cs"; + + internal const string DisableAntiforgeryAttributeName = "DisableAntiforgeryAttribute"; + internal const string DisableAntiforgeryAttributeHint = $"{DisableAntiforgeryAttributeFullyQualifiedName}.gs.cs"; + + internal const string ShortCircuitAttributeName = "ShortCircuitAttribute"; + internal const string ShortCircuitAttributeHint = $"{ShortCircuitAttributeFullyQualifiedName}.gs.cs"; + + internal const string DisableRequestTimeoutAttributeName = "DisableRequestTimeoutAttribute"; + internal const string DisableRequestTimeoutAttributeHint = $"{DisableRequestTimeoutAttributeFullyQualifiedName}.gs.cs"; + + internal const string DisableValidationAttributeName = "DisableValidationAttribute"; + internal const string DisableValidationAttributeHint = $"{DisableValidationAttributeFullyQualifiedName}.gs.cs"; + + internal const string RequestTimeoutAttributeName = "RequestTimeoutAttribute"; + internal const string RequestTimeoutAttributeHint = $"{RequestTimeoutAttributeFullyQualifiedName}.gs.cs"; + + internal const string OrderAttributeName = "OrderAttribute"; + internal const string OrderAttributeHint = $"{OrderAttributeFullyQualifiedName}.gs.cs"; + + internal const string MapGroupAttributeName = "MapGroupAttribute"; + internal const string MapGroupAttributeHint = $"{MapGroupAttributeFullyQualifiedName}.gs.cs"; + + internal const string SummaryAttributeName = "SummaryAttribute"; + internal const string SummaryAttributeHint = $"{SummaryAttributeFullyQualifiedName}.gs.cs"; + + internal const string DisplayNameAttributeName = nameof(DisplayNameAttribute); + internal const string DescriptionAttributeName = nameof(DescriptionAttribute); + internal const string AllowAnonymousAttributeName = "AllowAnonymousAttribute"; + internal const string TagsAttributeName = "TagsAttribute"; + internal const string ExcludeFromDescriptionAttributeName = "ExcludeFromDescriptionAttribute"; + + internal const string EndpointFilterAttributeName = "EndpointFilterAttribute"; + internal const string EndpointFilterAttributeHint = $"{EndpointFilterAttributeFullyQualifiedName}.gs.cs"; + + internal const string AcceptsAttributeName = "AcceptsAttribute"; + internal const string AcceptsAttributeHint = $"{AcceptsAttributeFullyQualifiedName}.gs.cs"; + + internal const string ProducesResponseAttributeName = "ProducesResponseAttribute"; + internal const string ProducesResponseAttributeHint = $"{ProducesResponseAttributeFullyQualifiedName}.gs.cs"; + + internal const string ProducesProblemAttributeName = "ProducesProblemAttribute"; + internal const string ProducesProblemAttributeHint = $"{ProducesProblemAttributeFullyQualifiedName}.gs.cs"; + + internal const string ProducesValidationProblemAttributeName = "ProducesValidationProblemAttribute"; + internal const string ProducesValidationProblemAttributeHint = $"{ProducesValidationProblemAttributeFullyQualifiedName}.gs.cs"; + + internal const string RoutingNamespace = $"{BaseNamespace}.Routing"; + + internal const string AddEndpointHandlersClassName = "EndpointServicesExtensions"; + internal const string AddEndpointHandlersMethodName = "AddEndpointHandlers"; + internal const string AddEndpointHandlersMethodHint = $"{AddEndpointHandlersMethodFullyQualifiedName}.g.cs"; + + internal const string UseEndpointHandlersClassName = "EndpointRouteBuilderExtensions"; + internal const string UseEndpointHandlersMethodName = "MapEndpointHandlers"; + internal const string UseEndpointHandlersMethodHint = $"{UseEndpointHandlersMethodFullyQualifiedName}.g.cs"; + + internal const string ConfigureMethodName = "Configure"; + internal const string AsyncSuffix = "Async"; + internal const string ApplicationJsonContentType = "application/json"; + internal const string GlobalPrefix = "global::"; + internal const string Dot = "."; + + internal static readonly string[] AttributesNamespaceParts = AttributesNamespace.Split('.'); + internal static readonly string[] AspNetCoreHttpNamespaceParts = ["Microsoft", "AspNetCore", "Http"]; + internal static readonly string[] AspNetCoreMvcNamespaceParts = ["Microsoft", "AspNetCore", "Mvc"]; + internal static readonly string[] AspNetCoreAuthorizationNamespaceParts = ["Microsoft", "AspNetCore", "Authorization"]; + internal static readonly string[] AspNetCoreRoutingNamespaceParts = ["Microsoft", "AspNetCore", "Routing"]; + internal static readonly string[] ExtensionsDependencyInjectionNamespaceParts = ["Microsoft", "Extensions", "DependencyInjection"]; + internal static readonly string[] ComponentModelNamespaceParts = ["System", "ComponentModel"]; + private const string BaseNamespace = "Microsoft.AspNetCore.Generated"; + private const string AttributesNamespace = $"{BaseNamespace}.Attributes"; + private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}"; + private const string RequireCorsAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireCorsAttributeName}"; + private const string RequireRateLimitingAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireRateLimitingAttributeName}"; + private const string RequireHostAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireHostAttributeName}"; + private const string DisableAntiforgeryAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableAntiforgeryAttributeName}"; + private const string ShortCircuitAttributeFullyQualifiedName = $"{AttributesNamespace}.{ShortCircuitAttributeName}"; + private const string DisableRequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableRequestTimeoutAttributeName}"; + private const string DisableValidationAttributeFullyQualifiedName = $"{AttributesNamespace}.{DisableValidationAttributeName}"; + private const string RequestTimeoutAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequestTimeoutAttributeName}"; + private const string OrderAttributeFullyQualifiedName = $"{AttributesNamespace}.{OrderAttributeName}"; + private const string MapGroupAttributeFullyQualifiedName = $"{AttributesNamespace}.{MapGroupAttributeName}"; + private const string SummaryAttributeFullyQualifiedName = $"{AttributesNamespace}.{SummaryAttributeName}"; + private const string EndpointFilterAttributeFullyQualifiedName = $"{AttributesNamespace}.{EndpointFilterAttributeName}"; + private const string AcceptsAttributeFullyQualifiedName = $"{AttributesNamespace}.{AcceptsAttributeName}"; + private const string ProducesResponseAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesResponseAttributeName}"; + private const string ProducesProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesProblemAttributeName}"; + private const string ProducesValidationProblemAttributeFullyQualifiedName = $"{AttributesNamespace}.{ProducesValidationProblemAttributeName}"; + private const string AddEndpointHandlersMethodFullyQualifiedName = $"{RoutingNamespace}.{AddEndpointHandlersMethodName}"; + private const string UseEndpointHandlersMethodFullyQualifiedName = $"{RoutingNamespace}.{UseEndpointHandlersMethodName}"; +} diff --git a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs new file mode 100644 index 0000000..60c5bbc --- /dev/null +++ b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs @@ -0,0 +1,33 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct EndpointConfiguration +{ + public required string? DisplayName { get; init; } + public required string? Summary { get; init; } + public required string? Description { get; init; } + public required EquatableImmutableArray? Tags { get; init; } + public required EquatableImmutableArray? Accepts { get; init; } + public required EquatableImmutableArray? Produces { get; init; } + public required EquatableImmutableArray? ProducesProblem { get; init; } + public required EquatableImmutableArray? ProducesValidationProblem { get; init; } + public required bool ExcludeFromDescription { get; init; } + public required bool RequireAuthorization { get; init; } + public required EquatableImmutableArray? AuthorizationPolicies { get; init; } + public required bool DisableAntiforgery { get; init; } + public required bool AllowAnonymous { get; init; } + public required bool RequireCors { get; init; } + public required string? CorsPolicyName { get; init; } + public required EquatableImmutableArray? RequiredHosts { get; init; } + public required bool RequireRateLimiting { get; init; } + public required string? RateLimitingPolicyName { get; init; } + public required EquatableImmutableArray? EndpointFilterTypes { get; init; } + public required bool ShortCircuit { get; init; } + public required bool DisableValidation { get; init; } + public required bool DisableRequestTimeout { get; init; } + public required bool WithRequestTimeout { get; init; } + public required string? RequestTimeoutPolicyName { get; init; } + public required int? Order { get; init; } + public required string? GroupIdentifier { get; init; } + public required string? GroupPattern { get; init; } + public required string? GroupName { get; init; } +} diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs new file mode 100644 index 0000000..f986ac9 --- /dev/null +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -0,0 +1,299 @@ +using System.Runtime.CompilerServices; +using Microsoft.CodeAnalysis; +using static GeneratedEndpoints.Common.Constants; + +namespace GeneratedEndpoints.Common; + +internal static class EndpointConfigurationFactory +{ + private static readonly ConditionalWeakTable GeneratedAttributeKindCache = new(); + + public static EndpointConfiguration Create(ISymbol symbol) + { + var attributes = symbol.GetAttributes(); + + string? displayName = null; + string? description = null; + EquatableImmutableArray? tags = null; + bool? requireAuthorization = null; + EquatableImmutableArray? authorizationPolicies = null; + bool? disableAntiforgery = null; + bool? allowAnonymous = null; + bool? excludeFromDescription = null; + List? accepts = null; + List? produces = null; + List? producesProblem = null; + List? producesValidationProblem = null; + bool? requireCors = null; + string? corsPolicyName = null; + EquatableImmutableArray? requiredHosts = null; + bool? requireRateLimiting = null; + string? rateLimitingPolicyName = null; + List? endpointFilters = null; + HashSet? endpointFilterSet = null; + bool? shortCircuit = null; + bool? disableValidation = null; + bool? disableRequestTimeout = null; + bool? withRequestTimeout = null; + string? requestTimeoutPolicyName = null; + int? order = null; + string? groupIdentifier = null; + string? groupPattern = null; + string? groupName = null; + string? summary = null; + + foreach (var attribute in attributes) + { + var attributeClass = attribute.AttributeClass; + if (attributeClass is null) + continue; + + var attributeKind = GetGeneratedAttributeKind(attributeClass); + switch (attributeKind) + { + case RequestHandlerAttributeKind.ShortCircuit: + shortCircuit = true; + continue; + case RequestHandlerAttributeKind.DisableValidation: + disableValidation = true; + continue; + case RequestHandlerAttributeKind.DisableRequestTimeout: + disableRequestTimeout = true; + continue; + case RequestHandlerAttributeKind.RequestTimeout: + requestTimeoutPolicyName = attribute.GetConstructorStringValue(); + withRequestTimeout = true; + continue; + case RequestHandlerAttributeKind.Order: + order = attribute.GetConstructorIntValue(); + continue; + case RequestHandlerAttributeKind.MapGroup: + groupIdentifier = GetMapGroupIdentifier(symbol); + groupPattern = attribute.GetConstructorStringValue() ?? ""; + groupName = attribute.GetNamedStringValue(NameAttributeNamedParameter); + continue; + case RequestHandlerAttributeKind.Summary: + summary = attribute.GetConstructorStringValue(); + continue; + case RequestHandlerAttributeKind.Accepts: + TryAddAcceptsMetadata(attribute, attributeClass, ref accepts); + continue; + case RequestHandlerAttributeKind.ProducesResponse: + TryAddProducesMetadata(attribute, attributeClass, ref produces); + continue; + case RequestHandlerAttributeKind.RequireAuthorization: + authorizationPolicies = attribute.GetConstructorStringArray(); + requireAuthorization = true; + continue; + case RequestHandlerAttributeKind.RequireCors: + corsPolicyName = attribute.GetConstructorStringValue(); + requireCors = true; + continue; + case RequestHandlerAttributeKind.RequireHost: + requiredHosts = attribute.GetConstructorStringArray(); + continue; + case RequestHandlerAttributeKind.RequireRateLimiting: + rateLimitingPolicyName = attribute.GetConstructorStringValue(); + requireRateLimiting = rateLimitingPolicyName is not null; + continue; + case RequestHandlerAttributeKind.EndpointFilter: + TryAddEndpointFilter(attribute, attributeClass, ref endpointFilters, ref endpointFilterSet); + continue; + case RequestHandlerAttributeKind.DisableAntiforgery: + disableAntiforgery = true; + continue; + case RequestHandlerAttributeKind.ProducesProblem: + { + var statusCode = attribute.GetConstructorIntValue() ?? 500; + var contentType = attribute.GetConstructorStringValue(1); + var additionalContentTypes = attribute.GetConstructorStringArray(2); + var producesProblemMetadata = new ProducesProblemMetadata(statusCode, contentType, additionalContentTypes); + + var producesProblemList = producesProblem ??= []; + producesProblemList.Add(producesProblemMetadata); + continue; + } + case RequestHandlerAttributeKind.ProducesValidationProblem: + { + var statusCode = attribute.GetConstructorIntValue() ?? 400; + var contentType = attribute.GetConstructorStringValue(1); + var additionalContentTypes = attribute.GetConstructorStringArray(2); + var producesValidationProblemMetadata = new ProducesValidationProblemMetadata(statusCode, contentType, additionalContentTypes); + + var producesValidationProblemList = producesValidationProblem ??= []; + producesValidationProblemList.Add(producesValidationProblemMetadata); + continue; + } + case RequestHandlerAttributeKind.DisplayName: + displayName = attribute.GetConstructorStringValue(); + break; + case RequestHandlerAttributeKind.Description: + description = attribute.GetConstructorStringValue(); + break; + case RequestHandlerAttributeKind.AllowAnonymous: + allowAnonymous = true; + break; + case RequestHandlerAttributeKind.Tags: + tags = attribute.GetConstructorStringArray(); + break; + case RequestHandlerAttributeKind.ExcludeFromDescription: + excludeFromDescription = true; + break; + case RequestHandlerAttributeKind.None: + default: + break; + } + } + + return new EndpointConfiguration + { + DisplayName = displayName, + Summary = summary, + Description = description, + Tags = tags, + Accepts = ToEquatableOrNull(accepts), + Produces = ToEquatableOrNull(produces), + ProducesProblem = ToEquatableOrNull(producesProblem), + ProducesValidationProblem = ToEquatableOrNull(producesValidationProblem), + ExcludeFromDescription = excludeFromDescription ?? false, + RequireAuthorization = requireAuthorization ?? false, + AuthorizationPolicies = authorizationPolicies, + DisableAntiforgery = disableAntiforgery ?? false, + AllowAnonymous = allowAnonymous ?? false, + RequireCors = requireCors ?? false, + CorsPolicyName = corsPolicyName, + RequiredHosts = requiredHosts, + RequireRateLimiting = requireRateLimiting ?? false, + RateLimitingPolicyName = rateLimitingPolicyName, + EndpointFilterTypes = ToEquatableOrNull(endpointFilters), + ShortCircuit = shortCircuit ?? false, + DisableValidation = disableValidation ?? false, + DisableRequestTimeout = disableRequestTimeout ?? false, + WithRequestTimeout = withRequestTimeout ?? false, + RequestTimeoutPolicyName = requestTimeoutPolicyName, + Order = order, + GroupIdentifier = groupIdentifier, + GroupPattern = groupPattern, + GroupName = groupName, + }; + } + + private static string? GetMapGroupIdentifier(ISymbol symbol) + { + if (symbol is not INamedTypeSymbol namedTypeSymbol) + return null; + + var className = namedTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + if (className.StartsWith(GlobalPrefix, StringComparison.Ordinal)) + className = className[GlobalPrefix.Length..]; + + var builder = StringBuilderPool.Get(className.Length + 8); + builder.Append('_'); + + foreach (var character in className) + builder.Append(char.IsLetterOrDigit(character) ? character : '_'); + + builder.Append("_Group"); + return StringBuilderPool.ToStringAndReturn(builder); + } + + private static RequestHandlerAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) + { + var definition = attributeClass.OriginalDefinition; + var cacheEntry = GeneratedAttributeKindCache.GetValue( + definition, static def => new GeneratedAttributeKindCacheEntry(def.GetRequestHandlerAttributeKind()) + ); + + return cacheEntry.Kind; + } + + private static EquatableImmutableArray? ToEquatableOrNull(List? values) + { + return values is { Count: > 0 } ? values.ToEquatableImmutableArray() : null; + } + + private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List? accepts) + { + string? requestType; + if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) + requestType = attributeClass.TypeArguments[0] + .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + else if (attribute.GetNamedTypeSymbol(RequestTypeAttributeNamedParameter) is { } requestTypeSymbol) + requestType = requestTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + else + return; + + var contentType = attribute.GetConstructorStringValue() ?? ApplicationJsonContentType; + var additionalContentTypes = attribute.GetConstructorStringArray(1); + var isOptional = attribute.GetNamedBoolValue(IsOptionalAttributeNamedParameter); + + var acceptMetadata = new AcceptsMetadata(requestType, contentType, additionalContentTypes, isOptional); + + var acceptsList = accepts ??= []; + acceptsList.Add(acceptMetadata); + } + + private static void TryAddProducesMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List? produces) + { + string? responseType; + if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) + responseType = attributeClass.TypeArguments[0] + .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + else if (attribute.GetNamedTypeSymbol(ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol) + responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + else + return; + + var statusCode = attribute.GetConstructorIntValue() ?? 200; + var contentType = attribute.GetConstructorStringValue(1); + var additionalContentTypes = attribute.GetConstructorStringArray(2); + + var producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes); + + var producesList = produces ??= []; + producesList.Add(producesMetadata); + } + + private static void TryAddEndpointFilter( + AttributeData attribute, + INamedTypeSymbol attributeClass, + ref List? endpointFilters, + ref HashSet? endpointFilterSet + ) + { + if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) + { + TryAddEndpointFilterType(attributeClass.TypeArguments[0], ref endpointFilters, ref endpointFilterSet); + return; + } + + if (attribute.ConstructorArguments.Length == 0) + return; + + if (attribute.ConstructorArguments[0].Value is ITypeSymbol filterTypeSymbol) + TryAddEndpointFilterType(filterTypeSymbol, ref endpointFilters, ref endpointFilterSet); + } + + private static void TryAddEndpointFilterType(ITypeSymbol? typeSymbol, ref List? endpointFilters, ref HashSet? endpointFilterSet) + { + if (typeSymbol is null or ITypeParameterSymbol or IErrorTypeSymbol) + return; + + var displayString = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + if (string.IsNullOrWhiteSpace(displayString)) + return; + + endpointFilterSet ??= new HashSet(StringComparer.Ordinal); + if (!endpointFilterSet.Add(displayString)) + return; + + endpointFilters ??= []; + endpointFilters.Add(displayString); + } + + private sealed class GeneratedAttributeKindCacheEntry(RequestHandlerAttributeKind kind) + { + public RequestHandlerAttributeKind Kind { get; } = kind; + } +} diff --git a/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs b/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs index be499bf..9270424 100644 --- a/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs +++ b/src/GeneratedEndpoints/Common/EquatableImmutableArray`1.cs @@ -1,5 +1,6 @@ using System.Collections; using System.Collections.Immutable; +using System.Runtime.InteropServices; namespace GeneratedEndpoints.Common; @@ -26,6 +27,35 @@ internal EquatableImmutableArray(ImmutableArray? array) _array = array; } + /// + /// Gets the underlying array, or if the array is empty. WARNING: This returns the underlying storage of the ImmutableArray and + /// must only be used when you can guarantee no other code observes it as immutable. + /// + internal T[]? AsArray() + { + return ImmutableCollectionsMarshal.AsArray(Array); + } + + /// + /// Sorts the underlying array in place using the specified comparer. WARNING: This mutates the underlying storage of the ImmutableArray and must only be + /// used when you can guarantee no other code observes it as immutable. + /// + internal void SortInPlace(IComparer? comparer = null) + { + if (_array is null) + return; + + var array = _array.Value; + if (array.Length <= 1) + return; + + comparer ??= Comparer.Default; + + var raw = ImmutableCollectionsMarshal.AsArray(array); + if (raw is not null) + System.Array.Sort(raw, comparer); + } + /// public bool Equals(EquatableImmutableArray other) { diff --git a/src/GeneratedEndpoints/Common/HttpAttributeDefinition.cs b/src/GeneratedEndpoints/Common/HttpAttributeDefinition.cs new file mode 100644 index 0000000..6c757de --- /dev/null +++ b/src/GeneratedEndpoints/Common/HttpAttributeDefinition.cs @@ -0,0 +1,5 @@ +using Microsoft.CodeAnalysis.Text; + +namespace GeneratedEndpoints.Common; + +internal readonly record struct HttpAttributeDefinition(string Name, string FullyQualifiedName, string Hint, string Verb, SourceText SourceText); diff --git a/src/GeneratedEndpoints/Common/IncrementalValueProviderExtensions.cs b/src/GeneratedEndpoints/Common/IncrementalValueProviderExtensions.cs index f2b97e7..46f3976 100644 --- a/src/GeneratedEndpoints/Common/IncrementalValueProviderExtensions.cs +++ b/src/GeneratedEndpoints/Common/IncrementalValueProviderExtensions.cs @@ -42,13 +42,14 @@ IncrementalValueProvider provider4 .Select((tuple, _) => (tuple.Left.Left.Left, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right)); } - public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5)> Combine( - this IncrementalValueProvider provider1, - IncrementalValueProvider provider2, - IncrementalValueProvider provider3, - IncrementalValueProvider provider4, - IncrementalValueProvider provider5 - ) + public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5)> + Combine( + this IncrementalValueProvider provider1, + IncrementalValueProvider provider2, + IncrementalValueProvider provider3, + IncrementalValueProvider provider4, + IncrementalValueProvider provider5 + ) { return provider1.Combine(provider2) .Combine(provider3) @@ -57,32 +58,36 @@ IncrementalValueProvider provider5 .Select((tuple, _) => (tuple.Left.Left.Left.Left, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right)); } - public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6)> Combine( - this IncrementalValueProvider provider1, - IncrementalValueProvider provider2, - IncrementalValueProvider provider3, - IncrementalValueProvider provider4, - IncrementalValueProvider provider5, - IncrementalValueProvider provider6 - ) + public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6)> + Combine( + this IncrementalValueProvider provider1, + IncrementalValueProvider provider2, + IncrementalValueProvider provider3, + IncrementalValueProvider provider4, + IncrementalValueProvider provider5, + IncrementalValueProvider provider6 + ) { return provider1.Combine(provider2) .Combine(provider3) .Combine(provider4) .Combine(provider5) .Combine(provider6) - .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right)); + .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, + tuple.Left.Right, tuple.Right) + ); } - public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7)> Combine( - this IncrementalValueProvider provider1, - IncrementalValueProvider provider2, - IncrementalValueProvider provider3, - IncrementalValueProvider provider4, - IncrementalValueProvider provider5, - IncrementalValueProvider provider6, - IncrementalValueProvider provider7 - ) + public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7)> + Combine( + this IncrementalValueProvider provider1, + IncrementalValueProvider provider2, + IncrementalValueProvider provider3, + IncrementalValueProvider provider4, + IncrementalValueProvider provider5, + IncrementalValueProvider provider6, + IncrementalValueProvider provider7 + ) { return provider1.Combine(provider2) .Combine(provider3) @@ -90,19 +95,22 @@ IncrementalValueProvider provider7 .Combine(provider5) .Combine(provider6) .Combine(provider7) - .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right)); + .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Right, + tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) + ); } - public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7, TItem8 Item8)> Combine( - this IncrementalValueProvider provider1, - IncrementalValueProvider provider2, - IncrementalValueProvider provider3, - IncrementalValueProvider provider4, - IncrementalValueProvider provider5, - IncrementalValueProvider provider6, - IncrementalValueProvider provider7, - IncrementalValueProvider provider8 - ) + public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7, TItem8 Item8)> + Combine( + this IncrementalValueProvider provider1, + IncrementalValueProvider provider2, + IncrementalValueProvider provider3, + IncrementalValueProvider provider4, + IncrementalValueProvider provider5, + IncrementalValueProvider provider6, + IncrementalValueProvider provider7, + IncrementalValueProvider provider8 + ) { return provider1.Combine(provider2) .Combine(provider3) @@ -111,12 +119,13 @@ IncrementalValueProvider provider8 .Combine(provider6) .Combine(provider7) .Combine(provider8) - .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, - tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) + .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Right, + tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) ); } - public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7, TItem8 Item8, TItem9 Item9)> + public static + IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7, TItem8 Item8, TItem9 Item9)> Combine( this IncrementalValueProvider provider1, IncrementalValueProvider provider2, @@ -137,13 +146,15 @@ IncrementalValueProvider provider9 .Combine(provider7) .Combine(provider8) .Combine(provider9) - .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Right, - tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) + .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Left.Left.Right, + tuple.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, + tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) ); } - public static IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7, TItem8 Item8, TItem9 Item9, TItem10 Item10)> - Combine( + public static + IncrementalValueProvider<(TItem1 Item1, TItem2 Item2, TItem3 Item3, TItem4 Item4, TItem5 Item5, TItem6 Item6, TItem7 Item7, TItem8 Item8, TItem9 Item9, + TItem10 Item10)> Combine( this IncrementalValueProvider provider1, IncrementalValueProvider provider2, IncrementalValueProvider provider3, @@ -165,8 +176,9 @@ IncrementalValueProvider provider10 .Combine(provider8) .Combine(provider9) .Combine(provider10) - .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Left.Right, - tuple.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) + .Select((tuple, _) => (tuple.Left.Left.Left.Left.Left.Left.Left.Left.Left, tuple.Left.Left.Left.Left.Left.Left.Left.Left.Right, + tuple.Left.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Left.Left.Right, + tuple.Left.Left.Left.Left.Right, tuple.Left.Left.Left.Right, tuple.Left.Left.Right, tuple.Left.Right, tuple.Right) ); } } diff --git a/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs b/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs new file mode 100644 index 0000000..bb4cdfc --- /dev/null +++ b/src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs @@ -0,0 +1,126 @@ +using System.Collections.Immutable; +using Microsoft.CodeAnalysis; +using static GeneratedEndpoints.Common.AttributeSymbolMatcher; +using static GeneratedEndpoints.Common.Constants; + +namespace GeneratedEndpoints.Common; + +// ReSharper disable ForCanBeConvertedToForeach +// ReSharper disable LoopCanBeConvertedToQuery +// Do not refactor, use for loop to avoid allocations. + +internal static class MethodSymbolExtensions +{ + public static EquatableImmutableArray GetParameters(this IMethodSymbol methodSymbol, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + var methodParameters = ImmutableArray.CreateBuilder(methodSymbol.Parameters.Length); + + for (var index = 0; index < methodSymbol.Parameters.Length; index++) + { + cancellationToken.ThrowIfCancellationRequested(); + + var parameterSymbol = methodSymbol.Parameters[index]; + var parameterName = parameterSymbol.Name; + var parameterType = parameterSymbol.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var bindingPrefix = GetBindingPrefix(parameterSymbol); + var parameter = new Parameter(parameterName, parameterType, bindingPrefix); + + methodParameters.Add(parameter); + } + + return methodParameters.ToEquatableImmutable(); + } + + private static string GetBindingPrefix(IParameterSymbol parameter) + { + var source = BindingSource.None; + TypedConstant? typedKey = null; + string? bindingName = null; + + foreach (var attribute in parameter.GetAttributes()) + { + var attributeClass = attribute.AttributeClass; + if (attributeClass is null) + continue; + + var attributeSource = GetBindingSourceFromAttributeClass(attributeClass); + if (attributeSource == BindingSource.None) + continue; + + source = attributeSource; + switch (attributeSource) + { + case BindingSource.FromRoute: + case BindingSource.FromQuery: + case BindingSource.FromHeader: + case BindingSource.FromForm: + bindingName = attribute.GetNamedStringValue(NameAttributeNamedParameter); + break; + case BindingSource.FromKeyedServices: + typedKey = attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0] : null; + break; + } + } + + var bindingPrefix = GetBindingSourceAttribute(source, typedKey, bindingName); + + return bindingPrefix; + } + + private static BindingSource GetBindingSourceFromAttributeClass(INamedTypeSymbol attributeClass) + { + var definition = attributeClass.OriginalDefinition; + var namespaceSymbol = definition.ContainingNamespace; + + return definition.Name switch + { + "FromRouteAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromRoute, + "FromQueryAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromQuery, + "FromHeaderAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromHeader, + "FromBodyAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromBody, + "FromFormAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromForm, + "FromServicesAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromServices, + "FromKeyedServicesAttribute" when IsInNamespace(namespaceSymbol, ExtensionsDependencyInjectionNamespaceParts) => BindingSource.FromKeyedServices, + "AsParametersAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreHttpNamespaceParts) => BindingSource.AsParameters, + _ => BindingSource.None, + }; + } + + private static string GetBindingSourceAttribute(BindingSource source, TypedConstant? typedKey, string? bindingName) + { + switch (source) + { + case BindingSource.None: + return ""; + case BindingSource.FromRoute: + return FormatBindingAttribute("FromRoute", bindingName); + case BindingSource.FromQuery: + return FormatBindingAttribute("FromQuery", bindingName); + case BindingSource.FromHeader: + return FormatBindingAttribute("FromHeader", bindingName); + case BindingSource.FromBody: + return FormatBindingAttribute("FromBody", bindingName); + case BindingSource.FromForm: + return FormatBindingAttribute("FromForm", bindingName); + case BindingSource.FromServices: + return "[FromServices] "; + case BindingSource.FromKeyedServices: + var key = typedKey?.ToConstLiteral(); + return $"[FromKeyedServices({key})] "; + case BindingSource.AsParameters: + return "[AsParameters] "; + default: + return ""; + } + } + + private static string FormatBindingAttribute(string attributeName, string? bindingName) + { + if (bindingName is null) + return $"[{attributeName}] "; + + return $"[{attributeName}(Name = {bindingName.ToStringLiteral()})] "; + } +} diff --git a/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs b/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs index d2e5370..36c1bca 100644 --- a/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs +++ b/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs @@ -1,19 +1,51 @@ using Microsoft.CodeAnalysis; +using static GeneratedEndpoints.Common.Constants; namespace GeneratedEndpoints.Common; /// Provides extension methods for working with named type symbols. internal static class NamedTypeSymbolExtensions { - public static bool HasTypeArgument(this INamedTypeSymbol namedTypeSymbol, out ITypeSymbol typeParameter) + public static RequestHandlerAttributeKind GetRequestHandlerAttributeKind(this INamedTypeSymbol definition) { - if (namedTypeSymbol.TypeArguments.Length == 1) - { - typeParameter = namedTypeSymbol.TypeArguments[0]; - return true; - } + if (AttributeSymbolMatcher.IsAttribute(definition, DisplayNameAttributeName, ComponentModelNamespaceParts)) + return RequestHandlerAttributeKind.DisplayName; + + if (AttributeSymbolMatcher.IsAttribute(definition, DescriptionAttributeName, ComponentModelNamespaceParts)) + return RequestHandlerAttributeKind.Description; + + if (AttributeSymbolMatcher.IsAttribute(definition, AllowAnonymousAttributeName, AspNetCoreAuthorizationNamespaceParts)) + return RequestHandlerAttributeKind.AllowAnonymous; + + if (AttributeSymbolMatcher.IsAttribute(definition, TagsAttributeName, AspNetCoreHttpNamespaceParts)) + return RequestHandlerAttributeKind.Tags; - typeParameter = null!; - return false; + if (AttributeSymbolMatcher.IsAttribute(definition, ExcludeFromDescriptionAttributeName, AspNetCoreRoutingNamespaceParts)) + return RequestHandlerAttributeKind.ExcludeFromDescription; + + if (!AttributeSymbolMatcher.IsInNamespace(definition.ContainingNamespace, AttributesNamespaceParts)) + return RequestHandlerAttributeKind.None; + + return definition.Name switch + { + ShortCircuitAttributeName => RequestHandlerAttributeKind.ShortCircuit, + DisableValidationAttributeName => RequestHandlerAttributeKind.DisableValidation, + DisableRequestTimeoutAttributeName => RequestHandlerAttributeKind.DisableRequestTimeout, + RequestTimeoutAttributeName => RequestHandlerAttributeKind.RequestTimeout, + OrderAttributeName => RequestHandlerAttributeKind.Order, + MapGroupAttributeName => RequestHandlerAttributeKind.MapGroup, + SummaryAttributeName => RequestHandlerAttributeKind.Summary, + AcceptsAttributeName => RequestHandlerAttributeKind.Accepts, + ProducesResponseAttributeName => RequestHandlerAttributeKind.ProducesResponse, + RequireAuthorizationAttributeName => RequestHandlerAttributeKind.RequireAuthorization, + RequireCorsAttributeName => RequestHandlerAttributeKind.RequireCors, + RequireHostAttributeName => RequestHandlerAttributeKind.RequireHost, + RequireRateLimitingAttributeName => RequestHandlerAttributeKind.RequireRateLimiting, + EndpointFilterAttributeName => RequestHandlerAttributeKind.EndpointFilter, + DisableAntiforgeryAttributeName => RequestHandlerAttributeKind.DisableAntiforgery, + ProducesProblemAttributeName => RequestHandlerAttributeKind.ProducesProblem, + ProducesValidationProblemAttributeName => RequestHandlerAttributeKind.ProducesValidationProblem, + _ => RequestHandlerAttributeKind.None, + }; } } diff --git a/src/GeneratedEndpoints/Common/Parameter.cs b/src/GeneratedEndpoints/Common/Parameter.cs new file mode 100644 index 0000000..6e4e53f --- /dev/null +++ b/src/GeneratedEndpoints/Common/Parameter.cs @@ -0,0 +1,3 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct Parameter(string Name, string Type, string BindingPrefix); diff --git a/src/GeneratedEndpoints/Common/ProducesMetadata.cs b/src/GeneratedEndpoints/Common/ProducesMetadata.cs new file mode 100644 index 0000000..e7f223c --- /dev/null +++ b/src/GeneratedEndpoints/Common/ProducesMetadata.cs @@ -0,0 +1,8 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct ProducesMetadata( + string ResponseType, + int StatusCode, + string? ContentType, + EquatableImmutableArray? AdditionalContentTypes +); diff --git a/src/GeneratedEndpoints/Common/ProducesProblemMetadata.cs b/src/GeneratedEndpoints/Common/ProducesProblemMetadata.cs new file mode 100644 index 0000000..443f70b --- /dev/null +++ b/src/GeneratedEndpoints/Common/ProducesProblemMetadata.cs @@ -0,0 +1,3 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct ProducesProblemMetadata(int StatusCode, string? ContentType, EquatableImmutableArray? AdditionalContentTypes); diff --git a/src/GeneratedEndpoints/Common/ProducesValidationProblemMetadata.cs b/src/GeneratedEndpoints/Common/ProducesValidationProblemMetadata.cs new file mode 100644 index 0000000..2234852 --- /dev/null +++ b/src/GeneratedEndpoints/Common/ProducesValidationProblemMetadata.cs @@ -0,0 +1,3 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct ProducesValidationProblemMetadata(int StatusCode, string? ContentType, EquatableImmutableArray? AdditionalContentTypes); diff --git a/src/GeneratedEndpoints/Common/RequestHandler.cs b/src/GeneratedEndpoints/Common/RequestHandler.cs new file mode 100644 index 0000000..3ca8301 --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandler.cs @@ -0,0 +1,41 @@ +using static GeneratedEndpoints.Common.Constants; + +namespace GeneratedEndpoints.Common; + +internal record struct RequestHandler +{ + public required RequestHandlerClass Class { get; init; } + public required RequestHandlerMethod Method { get; init; } + public required string HttpMethod { get; init; } + public required string Pattern { get; init; } + + public required string? Name + { + readonly get => _name; + init => _name = value; + } + + private string? _name; + + public void SetFullyQualifiedName() + { + ReadOnlySpan className = Class.Name; + ReadOnlySpan methodName = Method.Name; + + if (className.StartsWith(GlobalPrefix)) + className = className[GlobalPrefix.Length..]; + + var classLen = className.Length; + var methodLen = methodName.Length; + var total = classLen + 1 + methodLen; + + Span buffer = stackalloc char[total]; + className.CopyTo(buffer); + + buffer[classLen] = '.'; + + methodName.CopyTo(buffer[(classLen + 1)..]); + + _name = buffer.ToString(); + } +} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerAttributeKind.cs b/src/GeneratedEndpoints/Common/RequestHandlerAttributeKind.cs new file mode 100644 index 0000000..1a80118 --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerAttributeKind.cs @@ -0,0 +1,28 @@ +namespace GeneratedEndpoints.Common; + +internal enum RequestHandlerAttributeKind +{ + None = 0, + ShortCircuit, + DisableValidation, + DisableRequestTimeout, + RequestTimeout, + Order, + MapGroup, + Summary, + Accepts, + ProducesResponse, + RequireAuthorization, + RequireCors, + RequireHost, + RequireRateLimiting, + EndpointFilter, + DisableAntiforgery, + ProducesProblem, + ProducesValidationProblem, + DisplayName, + Description, + AllowAnonymous, + Tags, + ExcludeFromDescription, +} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClass.cs b/src/GeneratedEndpoints/Common/RequestHandlerClass.cs new file mode 100644 index 0000000..c5e868b --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerClass.cs @@ -0,0 +1,9 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct RequestHandlerClass( + string Name, + bool IsStatic, + bool HasConfigureMethod, + bool ConfigureMethodAcceptsServiceProvider, + EndpointConfiguration Configuration +); diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs new file mode 100644 index 0000000..3226132 --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs @@ -0,0 +1,157 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using static GeneratedEndpoints.Common.Constants; + +namespace GeneratedEndpoints.Common; + +internal sealed class RequestHandlerClassCacheEntry +{ + private readonly object _lock = new(); + private RequestHandlerClass _value; + private bool _initialized; + + public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, CancellationToken cancellationToken) + { + if (_initialized) + return _value; + + lock (_lock) + { + if (_initialized) + return _value; + + cancellationToken.ThrowIfCancellationRequested(); + + var name = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var isStatic = classSymbol.IsStatic; + var configureMethodDetails = GetConfigureMethodDetails(classSymbol, cancellationToken); + + var classConfiguration = EndpointConfigurationFactory.Create(classSymbol); + + _value = new RequestHandlerClass(name, isStatic, configureMethodDetails.HasConfigureMethod, + configureMethodDetails.ConfigureMethodAcceptsServiceProvider, classConfiguration + ); + _initialized = true; + return _value; + } + } + + private static ConfigureMethodDetails GetConfigureMethodDetails(INamedTypeSymbol classSymbol, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + var hasConfigureMethod = false; + var acceptsServiceProvider = false; + foreach (var member in classSymbol.GetMembers(ConfigureMethodName)) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (member is not IMethodSymbol methodSymbol) + continue; + + if (IsConfigureMethod(methodSymbol, out var methodAcceptsServiceProvider)) + { + hasConfigureMethod = true; + if (methodAcceptsServiceProvider) + { + acceptsServiceProvider = true; + break; + } + } + } + + return new ConfigureMethodDetails(hasConfigureMethod, acceptsServiceProvider); + } + + private static bool IsConfigureMethod(IMethodSymbol methodSymbol, out bool acceptsServiceProvider) + { + acceptsServiceProvider = false; + + if (!methodSymbol.IsStatic) + return false; + + if (methodSymbol.TypeParameters.Length != 1) + return false; + + if (methodSymbol.Parameters.Length is < 1 or > 2) + return false; + + var builderTypeParameter = methodSymbol.TypeParameters[0]; + var builderParameter = methodSymbol.Parameters[0]; + + if (!SymbolEqualityComparer.Default.Equals(builderParameter.Type, builderTypeParameter)) + return false; + + if (methodSymbol.Parameters.Length == 2) + { + var serviceProviderParameter = methodSymbol.Parameters[1]; + if (!IsServiceProviderParameter(serviceProviderParameter.Type)) + return false; + + acceptsServiceProvider = true; + } + + if (!methodSymbol.ReturnsVoid) + return false; + + if (!HasEndpointConventionBuilderConstraint(builderTypeParameter, methodSymbol)) + return false; + + return true; + } + + private static bool IsServiceProviderParameter(ITypeSymbol typeSymbol) + { + return MatchesServiceProvider(typeSymbol); + } + + private static bool HasEndpointConventionBuilderConstraint(ITypeParameterSymbol builderTypeParameter, IMethodSymbol methodSymbol) + { + var symbolMatches = builderTypeParameter.ConstraintTypes.Any(MatchesEndpointConventionBuilder); + if (symbolMatches) + return true; + + return methodSymbol.DeclaringSyntaxReferences + .Select(reference => reference.GetSyntax()) + .OfType() + .SelectMany(methodSyntax => methodSyntax.ConstraintClauses) + .Where(clause => string.Equals(clause.Name.Identifier.ValueText, builderTypeParameter.Name, StringComparison.Ordinal)) + .SelectMany(clause => clause.Constraints.OfType()) + .Any(constraint => IsEndpointConventionBuilderIdentifier(constraint.Type)); + } + + private static bool IsEndpointConventionBuilderIdentifier(TypeSyntax typeSyntax) + { + return typeSyntax switch + { + QualifiedNameSyntax qualified => IsEndpointConventionBuilderIdentifier(qualified.Right), + AliasQualifiedNameSyntax alias => IsEndpointConventionBuilderIdentifier(alias.Name), + SimpleNameSyntax simple => string.Equals(simple.Identifier.ValueText, "IEndpointConventionBuilder", StringComparison.Ordinal), + _ => false, + }; + } + + private static bool MatchesEndpointConventionBuilder(ITypeSymbol typeSymbol) + { + if (typeSymbol is not INamedTypeSymbol namedType) + return false; + + if (!string.Equals(namedType.Name, "IEndpointConventionBuilder", StringComparison.Ordinal)) + return false; + + var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? ""; + return string.Equals(containingNamespace, "Microsoft.AspNetCore.Builder", StringComparison.Ordinal); + } + + private static bool MatchesServiceProvider(ITypeSymbol typeSymbol) + { + if (typeSymbol is not INamedTypeSymbol namedType) + return false; + + if (!string.Equals(namedType.Name, "IServiceProvider", StringComparison.Ordinal)) + return false; + + var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? ""; + return string.Equals(containingNamespace, "System", StringComparison.Ordinal); + } +} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerComparer.cs b/src/GeneratedEndpoints/Common/RequestHandlerComparer.cs new file mode 100644 index 0000000..4d3d61a --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerComparer.cs @@ -0,0 +1,23 @@ +namespace GeneratedEndpoints.Common; + +internal sealed class RequestHandlerComparer : IComparer +{ + public static RequestHandlerComparer Instance { get; } = new(); + + public int Compare(RequestHandler x, RequestHandler y) + { + var comparison = string.Compare(x.Class.Name, y.Class.Name, StringComparison.Ordinal); + if (comparison != 0) + return comparison; + + comparison = string.Compare(x.Method.Name, y.Method.Name, StringComparison.Ordinal); + if (comparison != 0) + return comparison; + + comparison = string.Compare(x.HttpMethod, y.HttpMethod, StringComparison.Ordinal); + if (comparison != 0) + return comparison; + + return string.Compare(x.Pattern, y.Pattern, StringComparison.Ordinal); + } +} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs b/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs new file mode 100644 index 0000000..54e2958 --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs @@ -0,0 +1,8 @@ +namespace GeneratedEndpoints.Common; + +internal readonly record struct RequestHandlerMethod( + string Name, + bool IsStatic, + EquatableImmutableArray Parameters, + EndpointConfiguration Configuration +); diff --git a/src/GeneratedEndpoints/Common/StringExtensions.cs b/src/GeneratedEndpoints/Common/StringExtensions.cs new file mode 100644 index 0000000..5867811 --- /dev/null +++ b/src/GeneratedEndpoints/Common/StringExtensions.cs @@ -0,0 +1,78 @@ +using System.Globalization; + +namespace GeneratedEndpoints.Common; + +internal static class StringExtensions +{ + public static string ToStringLiteral(this string? value) + { + if (value is null) + return "null"; + + var firstEscapeIndex = -1; + for (var i = 0; i < value.Length; i++) + { + var c = value[i]; + if (c == '\"' || c == '\\' || c == '\n' || c == '\r' || c == '\t' || c == '\0' || char.IsControl(c)) + { + firstEscapeIndex = i; + break; + } + } + + if (firstEscapeIndex < 0) + return string.Concat("\"", value, "\""); + + var sb = StringBuilderPool.Get(value.Length + 2); + sb.Append('"'); + if (firstEscapeIndex > 0) + sb.Append(value, 0, firstEscapeIndex); + + for (var i = firstEscapeIndex; i < value.Length; i++) + { + var c = value[i]; + switch (c) + { + case '\"': + sb.Append("\\\""); + break; + case '\\': + sb.Append("\\\\"); + break; + case '\n': + sb.Append("\\n"); + break; + case '\r': + sb.Append("\\r"); + break; + case '\t': + sb.Append("\\t"); + break; + case '\0': + sb.Append("\\0"); + break; + default: + if (char.IsControl(c)) + sb.Append("\\u") + .Append(((int)c).ToString("x4", CultureInfo.InvariantCulture)); + else + sb.Append(c); + + break; + } + } + + sb.Append('"'); + return StringBuilderPool.ToStringAndReturn(sb); + } + + public static string? NormalizeOptionalString(this string? value) + { + return string.IsNullOrWhiteSpace(value) ? null : value!.Trim(); + } + + public static string NormalizeOrDefaultString(this string? value, string defaultValue) + { + return string.IsNullOrWhiteSpace(value) ? defaultValue : value!.Trim(); + } +} diff --git a/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs b/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs index 1047514..0b69f45 100644 --- a/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs +++ b/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs @@ -4,79 +4,47 @@ namespace GeneratedEndpoints.Common; internal static class TypeSymbolExtensions { - public static bool IsValueTask(this ITypeSymbol symbol, out INamedTypeSymbol valueTaskSymbol) + public static bool IsAwaitable(this ITypeSymbol symbol) { - if (symbol is INamedTypeSymbol - { - MetadataName: "ValueTask`1", - ContainingNamespace: + return symbol switch + { + INamedTypeSymbol { - Name: "Tasks", + MetadataName: "ValueTask`1", ContainingNamespace: { - Name: "Threading", - ContainingNamespace: - { - Name: "System", - ContainingNamespace.IsGlobalNamespace: true, - }, + Name: "Tasks", + ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, }, - }, - } namedTypeSymbol) - { - valueTaskSymbol = namedTypeSymbol; - return true; - } - - valueTaskSymbol = null!; - return false; - } - - public static bool IsTask(this ITypeSymbol symbol, out INamedTypeSymbol valueTaskSymbol) - { - if (symbol is INamedTypeSymbol - { - MetadataName: "Task`1", - ContainingNamespace: + } + or INamedTypeSymbol { - Name: "Tasks", + MetadataName: "Task`1", ContainingNamespace: { - Name: "Threading", - ContainingNamespace: - { - Name: "System", - ContainingNamespace.IsGlobalNamespace: true, - }, + Name: "Tasks", + ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, }, - }, - } namedTypeSymbol) - { - valueTaskSymbol = namedTypeSymbol; - return true; - } - - valueTaskSymbol = null!; - return false; - } - - public static bool IsLightResults(this ITypeSymbol symbol, out INamedTypeSymbol lightResultsSymbol) - { - if (symbol is INamedTypeSymbol - { - Name: "Result", - ContainingNamespace: + } + or INamedTypeSymbol { - Name: "LightResults", - ContainingNamespace.IsGlobalNamespace: true, - }, - } namedTypeSymbol) - { - lightResultsSymbol = namedTypeSymbol; - return true; - } - - lightResultsSymbol = null!; - return false; + MetadataName: "ValueTask", + ContainingNamespace: + { + Name: "Tasks", + ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, + }, + } + or INamedTypeSymbol + { + MetadataName: "Task", + ContainingNamespace: + { + Name: "Tasks", + ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, + }, + } => true, + _ => false, + }; } } diff --git a/src/GeneratedEndpoints/Common/TypedConstantExtensions.cs b/src/GeneratedEndpoints/Common/TypedConstantExtensions.cs new file mode 100644 index 0000000..1f29648 --- /dev/null +++ b/src/GeneratedEndpoints/Common/TypedConstantExtensions.cs @@ -0,0 +1,82 @@ +using System.Diagnostics.CodeAnalysis; +using System.Globalization; +using Microsoft.CodeAnalysis; + +namespace GeneratedEndpoints.Common; + +internal static class TypedConstantExtensions +{ + [SuppressMessage("Globalization", "CA1308: Normalize strings to uppercase", Justification = "C# boolean literals must be lowercase.")] + public static string ToConstLiteral(this TypedConstant tc) + { + if (tc.IsNull) + return "null"; + var v = tc.Value; + var t = tc.Type; + if (t is null) + return "null"; + + if (t.TypeKind != TypeKind.Enum) + return t.SpecialType switch + { + SpecialType.System_String => ((string?)v).ToStringLiteral(), + SpecialType.System_Char => $"'{EscapeChar((char)v!)}'", + SpecialType.System_Boolean => ((bool)v!).ToString() + .ToLowerInvariant(), + SpecialType.System_Double => ((double)v!).ToString("R", CultureInfo.InvariantCulture), + SpecialType.System_Single => ((float)v!).ToString("R", CultureInfo.InvariantCulture) + "f", + SpecialType.System_Decimal => ((decimal)v!).ToString(CultureInfo.InvariantCulture) + "m", + SpecialType.System_SByte => ((sbyte)v!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_Byte => ((byte)v!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_Int16 => ((short)v!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_UInt16 => ((ushort)v!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_Int32 => ((int)v!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_UInt32 => ((uint)v!).ToString(CultureInfo.InvariantCulture) + "u", + SpecialType.System_Int64 => ((long)v!).ToString(CultureInfo.InvariantCulture) + "L", + SpecialType.System_UInt64 => ((ulong)v!).ToString(CultureInfo.InvariantCulture) + "UL", + _ => (v?.ToString()).ToStringLiteral(), + }; + + var field = t.GetMembers() + .OfType() + .FirstOrDefault(f => f.HasConstantValue && Equals(f.ConstantValue, v)); + + if (field is not null) + return $"{t.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}.{field.Name}"; + + var underlying = ((INamedTypeSymbol)t).EnumUnderlyingType!; + var num = IntegralLiteral(v, underlying.SpecialType); + return $"({t.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}){num}"; + } + + private static string EscapeChar(char c) + { + return c switch + { + '\'' => "\\'", + '\\' => "\\\\", + '\n' => "\\n", + '\r' => "\\r", + '\t' => "\\t", + '\0' => "\\0", + _ when char.IsControl(c) => "\\u" + ((int)c).ToString("x4", CultureInfo.InvariantCulture), + _ => c.ToString(), + }; + } + + private static string IntegralLiteral(object? value, SpecialType underlying) + { + return underlying switch + { + SpecialType.System_SByte => ((sbyte)value!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_Byte => ((byte)value!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_Int16 => ((short)value!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_UInt16 => ((ushort)value!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_Int32 => ((int)value!).ToString(CultureInfo.InvariantCulture), + SpecialType.System_UInt32 => ((uint)value!).ToString(CultureInfo.InvariantCulture) + "u", + SpecialType.System_Int64 => ((long)value!).ToString(CultureInfo.InvariantCulture) + "L", + SpecialType.System_UInt64 => ((ulong)value!).ToString(CultureInfo.InvariantCulture) + "UL", + _ => "0", + }; + } +} diff --git a/src/GeneratedEndpoints/GeneratedEndpoints.csproj b/src/GeneratedEndpoints/GeneratedEndpoints.csproj index 55bf32f..ca27350 100644 --- a/src/GeneratedEndpoints/GeneratedEndpoints.csproj +++ b/src/GeneratedEndpoints/GeneratedEndpoints.csproj @@ -20,12 +20,12 @@ - + all runtime; build; native; contentfiles; analyzers; buildtransitive - + all @@ -36,9 +36,9 @@ GeneratedEndpoints - 10.0.0 - 10.0.0.0 - 10.0.0.0 + 10.0.1 + 10.0.1.0 + 10.0.1.0 en-US false diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.Types.cs b/src/GeneratedEndpoints/MinimalApiGenerator.Types.cs deleted file mode 100644 index 8ec7a24..0000000 --- a/src/GeneratedEndpoints/MinimalApiGenerator.Types.cs +++ /dev/null @@ -1,234 +0,0 @@ -using GeneratedEndpoints.Common; -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.Text; - -namespace GeneratedEndpoints; - -public sealed partial class MinimalApiGenerator -{ - private readonly record struct HttpAttributeDefinition( - string Name, - string FullyQualifiedName, - string Hint, - string Verb, - SourceText SourceText - ); - - private readonly record struct RequestHandler( - RequestHandlerClass Class, - RequestHandlerMethod Method, - string HttpMethod, - string Pattern, - EndpointConfiguration Configuration - ); - - private readonly record struct RequestHandlerClass( - string Name, - bool IsStatic, - bool HasConfigureMethod, - bool ConfigureMethodAcceptsServiceProvider, - string? MapGroupPattern, - string? MapGroupBuilderIdentifier, - EndpointConfiguration Configuration - ); - - private readonly record struct EndpointConfiguration( - RequestHandlerMetadata Metadata, - bool RequireAuthorization, - EquatableImmutableArray? AuthorizationPolicies, - bool DisableAntiforgery, - bool AllowAnonymous, - bool RequireCors, - string? CorsPolicyName, - EquatableImmutableArray? RequiredHosts, - bool RequireRateLimiting, - string? RateLimitingPolicyName, - EquatableImmutableArray? EndpointFilterTypes, - bool ShortCircuit, - bool DisableValidation, - bool DisableRequestTimeout, - bool WithRequestTimeout, - string? RequestTimeoutPolicyName, - int? Order, - string? EndpointGroupName - ); - - private readonly record struct RequestHandlerMethod(string Name, bool IsStatic, bool IsAwaitable, EquatableImmutableArray Parameters); - - private readonly record struct RequestHandlerMetadata( - string? Name, - string? DisplayName, - string? Summary, - string? Description, - EquatableImmutableArray? Tags, - EquatableImmutableArray? Accepts, - EquatableImmutableArray? Produces, - EquatableImmutableArray? ProducesProblem, - EquatableImmutableArray? ProducesValidationProblem, - bool ExcludeFromDescription - ); - - private readonly record struct AcceptsMetadata( - string RequestType, - string ContentType, - EquatableImmutableArray? AdditionalContentTypes, - bool IsOptional - ); - - private readonly record struct ProducesMetadata( - string ResponseType, - int StatusCode, - string? ContentType, - EquatableImmutableArray? AdditionalContentTypes - ); - - private readonly record struct ProducesProblemMetadata(int StatusCode, string? ContentType, EquatableImmutableArray? AdditionalContentTypes); - - private readonly record struct ProducesValidationProblemMetadata( - int StatusCode, - string? ContentType, - EquatableImmutableArray? AdditionalContentTypes - ); - - private readonly record struct Parameter(string Name, string Type, string BindingPrefix); - - private readonly record struct ConfigureMethodDetails(bool HasConfigureMethod, bool ConfigureMethodAcceptsServiceProvider); - - private struct EndpointAttributeState - { - public EquatableImmutableArray? Tags; - public bool? RequireAuthorization; - public EquatableImmutableArray? AuthorizationPolicies; - public bool? DisableAntiforgery; - public bool? AllowAnonymous; - public bool? ExcludeFromDescription; - public List? Accepts; - public List? Produces; - public List? ProducesProblem; - public List? ProducesValidationProblem; - public bool? RequireCors; - public string? CorsPolicyName; - public EquatableImmutableArray? RequiredHosts; - public bool? RequireRateLimiting; - public string? RateLimitingPolicyName; - public List? EndpointFilters; - public HashSet? EndpointFilterSet; - public bool HasAllowAnonymousAttribute; - public bool HasRequireAuthorizationAttribute; - public bool? ShortCircuit; - public bool? DisableValidation; - public bool? DisableRequestTimeout; - public bool? WithRequestTimeout; - public string? RequestTimeoutPolicyName; - public int? Order; - public string? EndpointGroupName; - public string? Summary; - } - - private enum GeneratedAttributeKind - { - None = 0, - ShortCircuit, - DisableValidation, - DisableRequestTimeout, - RequestTimeout, - Order, - MapGroup, - Summary, - Accepts, - ProducesResponse, - RequireAuthorization, - RequireCors, - RequireHost, - RequireRateLimiting, - EndpointFilter, - DisableAntiforgery, - ProducesProblem, - ProducesValidationProblem, - } - - private enum BindingSource - { - None = 0, - FromRoute = 1, - FromQuery = 2, - FromHeader = 3, - FromBody = 4, - FromForm = 5, - FromServices = 6, - FromKeyedServices = 7, - AsParameters = 8, - } - - private sealed class RequestHandlerComparer : IComparer - { - public static RequestHandlerComparer Instance { get; } = new(); - - public int Compare(RequestHandler x, RequestHandler y) - { - var comparison = string.Compare(x.Class.Name, y.Class.Name, StringComparison.Ordinal); - if (comparison != 0) - return comparison; - - comparison = string.Compare(x.Method.Name, y.Method.Name, StringComparison.Ordinal); - if (comparison != 0) - return comparison; - - comparison = string.Compare(x.HttpMethod, y.HttpMethod, StringComparison.Ordinal); - if (comparison != 0) - return comparison; - - return string.Compare(x.Pattern, y.Pattern, StringComparison.Ordinal); - } - } - - private sealed class CompilationTypeCache(Compilation compilation) - { - public INamedTypeSymbol? EndpointConventionBuilderSymbol { get; } = - compilation.GetTypeByMetadataName("Microsoft.AspNetCore.Builder.IEndpointConventionBuilder"); - - public INamedTypeSymbol? ServiceProviderSymbol { get; } = compilation.GetTypeByMetadataName("System.IServiceProvider"); - } - - private sealed class RequestHandlerClassCacheEntry - { - private readonly object _lock = new(); - private RequestHandlerClass _value; - private bool _initialized; - - public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, CompilationTypeCache compilationCache, CancellationToken cancellationToken) - { - if (_initialized) - return _value; - - lock (_lock) - { - if (_initialized) - return _value; - - cancellationToken.ThrowIfCancellationRequested(); - - var name = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var isStatic = classSymbol.IsStatic; - var configureMethodDetails = GetConfigureMethodDetails(classSymbol, compilationCache.EndpointConventionBuilderSymbol, - compilationCache.ServiceProviderSymbol, cancellationToken - ); - - var mapGroupPattern = GetMapGroupPattern(classSymbol); - var mapGroupIdentifier = mapGroupPattern is null ? null : GetMapGroupIdentifier(name); - var classConfiguration = GetEndpointConfiguration(classSymbol.GetAttributes(), null, null, null, false); - - _value = new RequestHandlerClass(name, isStatic, configureMethodDetails.HasConfigureMethod, - configureMethodDetails.ConfigureMethodAcceptsServiceProvider, mapGroupPattern, mapGroupIdentifier, classConfiguration - ); - _initialized = true; - return _value; - } - } - } - - private sealed class GeneratedAttributeKindCacheEntry(GeneratedAttributeKind kind) - { - public GeneratedAttributeKind Kind { get; } = kind; - } -} diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index c5ebad8..00181d4 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -1,27 +1,30 @@ -using System.Buffers; using System.Collections.Immutable; -using System.ComponentModel; -using System.Diagnostics.CodeAnalysis; -using System.Globalization; -using System.Text; +using System.Runtime.CompilerServices; using GeneratedEndpoints.Common; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Text; +using static GeneratedEndpoints.Common.Constants; namespace GeneratedEndpoints; +// ReSharper disable ForCanBeConvertedToForeach +// ReSharper disable LoopCanBeConvertedToQuery +// Do not refactor, use for loop to avoid allocations. + [Generator] -public sealed partial class MinimalApiGenerator : IIncrementalGenerator +public sealed class MinimalApiGenerator : IIncrementalGenerator { + private static readonly ConditionalWeakTable RequestHandlerClassCache = new(); + public void Initialize(IncrementalGeneratorInitializationContext context) { context.RegisterPostInitializationOutput(RegisterAttributes); var requestHandlerProviders = ImmutableArray.CreateBuilder>>(HttpAttributeDefinitions.Length); - foreach (var definition in HttpAttributeDefinitions) + for (var index = 0; index < HttpAttributeDefinitions.Length; index++) { + var definition = HttpAttributeDefinitions[index]; var handlers = context.SyntaxProvider .ForAttributeWithMetadataName(definition.FullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) .WhereNotNull() @@ -30,20 +33,12 @@ public void Initialize(IncrementalGeneratorInitializationContext context) requestHandlerProviders.Add(handlers); } - var requestHandlers = CombineRequestHandlers(requestHandlerProviders.MoveToImmutable()); + var requestHandlers = CombineRequestHandlers(requestHandlerProviders.MoveToImmutable()) + .Select((x, _) => x.ToEquatableImmutableArray()); context.RegisterSourceOutput(requestHandlers, GenerateSource); } - private static HttpAttributeDefinition CreateHttpAttributeDefinition(string attributeName, string verb, bool allowOptionalPattern = false) - { - var fullyQualifiedName = $"{AttributesNamespace}.{attributeName}"; - var hint = $"{fullyQualifiedName}.gs.cs"; - var summaryVerb = verb == FallbackHttpMethod ? "fallback" : verb; - var source = GenerateHttpAttributeSource(AttributesNamespace, attributeName, summaryVerb, allowOptionalPattern); - return new HttpAttributeDefinition(attributeName, fullyQualifiedName, hint, verb, SourceText.From(source, Encoding.UTF8)); - } - private static IncrementalValueProvider> CombineRequestHandlers( ImmutableArray>> handlerProviders ) @@ -66,58 +61,23 @@ private static void RegisterAttributes(IncrementalGeneratorPostInitializationCon foreach (var definition in HttpAttributeDefinitions) context.AddSource(definition.Hint, definition.SourceText); - context.AddSource(RequireAuthorizationAttributeHint, RequireAuthorizationAttributeSourceText); - context.AddSource(RequireCorsAttributeHint, RequireCorsAttributeSourceText); - context.AddSource(RequireRateLimitingAttributeHint, RequireRateLimitingAttributeSourceText); - context.AddSource(RequireHostAttributeHint, RequireHostAttributeSourceText); + context.AddSource(AcceptsAttributeHint, AcceptsAttributeSourceText); context.AddSource(DisableAntiforgeryAttributeHint, DisableAntiforgeryAttributeSourceText); - context.AddSource(ShortCircuitAttributeHint, ShortCircuitAttributeSourceText); context.AddSource(DisableRequestTimeoutAttributeHint, DisableRequestTimeoutAttributeSourceText); context.AddSource(DisableValidationAttributeHint, DisableValidationAttributeSourceText); - context.AddSource(RequestTimeoutAttributeHint, RequestTimeoutAttributeSourceText); - context.AddSource(OrderAttributeHint, OrderAttributeSourceText); - context.AddSource(MapGroupAttributeHint, MapGroupAttributeSourceText); - context.AddSource(SummaryAttributeHint, SummaryAttributeSourceText); - context.AddSource(AcceptsAttributeHint, AcceptsAttributeSourceText); context.AddSource(EndpointFilterAttributeHint, EndpointFilterAttributeSourceText); - context.AddSource(ProducesResponseAttributeHint, ProducesResponseAttributeSourceText); + context.AddSource(MapGroupAttributeHint, MapGroupAttributeSourceText); + context.AddSource(OrderAttributeHint, OrderAttributeSourceText); context.AddSource(ProducesProblemAttributeHint, ProducesProblemAttributeSourceText); + context.AddSource(ProducesResponseAttributeHint, ProducesResponseAttributeSourceText); context.AddSource(ProducesValidationProblemAttributeHint, ProducesValidationProblemAttributeSourceText); - } - - private static string GenerateHttpAttributeSource(string attributesNamespace, string attributeName, string summaryVerb, bool allowOptionalPattern = false) - { - return $$""" - {{FileHeader}} - - namespace {{attributesNamespace}}; - - /// - /// Identifies a method as an HTTP {{summaryVerb}} minimal API endpoint with the specified route pattern. - /// - [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = false)] - internal sealed class {{attributeName}} : global::System.Attribute - { - /// - /// Gets the route pattern for the endpoint. - /// - public string{{(allowOptionalPattern ? "?" : "")}} Pattern { get; } - - /// - /// Gets or sets the endpoint name. - /// - public string? Name { get; init; } - - /// - /// Initializes a new instance of the class. - /// - /// The route pattern for the endpoint. - public {{attributeName}}([global::System.Diagnostics.CodeAnalysis.StringSyntax("Route")] string{{(allowOptionalPattern ? "?" : "")}} pattern{{(allowOptionalPattern ? " = null" : "")}}) - { - Pattern = pattern; - } - } - """; + context.AddSource(RequestTimeoutAttributeHint, RequestTimeoutAttributeSourceText); + context.AddSource(RequireAuthorizationAttributeHint, RequireAuthorizationAttributeSourceText); + context.AddSource(RequireCorsAttributeHint, RequireCorsAttributeSourceText); + context.AddSource(RequireHostAttributeHint, RequireHostAttributeSourceText); + context.AddSource(RequireRateLimitingAttributeHint, RequireRateLimitingAttributeSourceText); + context.AddSource(ShortCircuitAttributeHint, ShortCircuitAttributeSourceText); + context.AddSource(SummaryAttributeHint, SummaryAttributeSourceText); } private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToken cancellationToken) @@ -131,1861 +91,135 @@ private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToke { cancellationToken.ThrowIfCancellationRequested(); - if (context.TargetSymbol is not IMethodSymbol requestHandlerMethodSymbol) + if (context.TargetSymbol is not IMethodSymbol methodSymbol) return null; var attribute = context.Attributes[0]; - var requestHandlerClass = GetRequestHandlerClass(requestHandlerMethodSymbol, context.SemanticModel.Compilation, cancellationToken); + var requestHandlerClass = GetRequestHandlerClass(methodSymbol, cancellationToken); if (requestHandlerClass is null) return null; - var requestHandlerMethod = GetRequestHandlerMethod(requestHandlerMethodSymbol, cancellationToken); - - var (httpMethod, pattern, name) = GetRequestHandlerAttribute(attribute, cancellationToken); + var requestHandlerMethod = GetRequestHandlerMethod(methodSymbol, cancellationToken); - var (displayName, description) = GetDisplayAndDescriptionAttributes(requestHandlerMethodSymbol); + var (httpMethod, pattern, name) = GetRequestHandlerAttribute(methodSymbol, attribute, cancellationToken); - name ??= RemoveAsyncSuffix(requestHandlerMethod.Name); - - var methodConfiguration = GetEndpointConfiguration(requestHandlerMethodSymbol.GetAttributes(), name, displayName, description, true); - - var requestHandler = new RequestHandler(requestHandlerClass.Value, requestHandlerMethod, httpMethod, pattern, methodConfiguration); + var requestHandler = new RequestHandler + { + Class = requestHandlerClass.Value, + Method = requestHandlerMethod, + HttpMethod = httpMethod, + Pattern = pattern, + Name = name, + }; return requestHandler; } - private static string RemoveAsyncSuffix(string methodName) - { - if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) - return methodName[..^AsyncSuffix.Length]; - - return methodName; - } - - private static ( string HttpMethod, string Pattern, string? Name ) GetRequestHandlerAttribute(AttributeData attribute, CancellationToken cancellationToken) + private static (string HttpMethod, string Pattern, string? Name) GetRequestHandlerAttribute( + IMethodSymbol methodSymbol, + AttributeData attribute, + CancellationToken cancellationToken + ) { cancellationToken.ThrowIfCancellationRequested(); var attributeName = attribute.AttributeClass?.Name ?? ""; - var httpMethod = HttpAttributeDefinitionsByName.TryGetValue(attributeName, out var definition) ? definition.Verb : ""; - - var pattern = (attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : "") ?? ""; - - string? name = null; - foreach (var namedArg in attribute.NamedArguments) - { - switch (namedArg.Key) - { - case NameAttributeNamedParameter: - { - var value = namedArg.Value.Value as string; - name = string.IsNullOrWhiteSpace(value) ? null : value!.Trim(); - break; - } - } - } + var pattern = attribute.GetConstructorStringValue() ?? ""; + var name = attribute.GetNamedStringValue(NameAttributeNamedParameter); + name ??= RemoveAsyncSuffix(methodSymbol.Name); return (httpMethod, pattern, name); } - private static (string? DisplayName, string? Description) GetDisplayAndDescriptionAttributes(IMethodSymbol methodSymbol) - { - string? displayName = null; - string? description = null; - - foreach (var attribute in methodSymbol.GetAttributes()) - { - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - if (IsAttribute(attributeClass, nameof(DisplayNameAttribute), ComponentModelNamespaceParts)) - { - displayName = NormalizeOptionalString(attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : null); - - continue; - } - - if (IsAttribute(attributeClass, nameof(DescriptionAttribute), ComponentModelNamespaceParts)) - description = NormalizeOptionalString(attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0].Value as string : null); - } - - return (displayName, description); - } - - private static EndpointConfiguration GetEndpointConfiguration( - ImmutableArray attributes, - string? name, - string? displayName, - string? description, - bool enforceMethodRequireAuthorizationRules - ) - { - var state = new EndpointAttributeState(); - - GetAdditionalRequestHandlerAttributeValues(attributes, ref state); - - if (enforceMethodRequireAuthorizationRules && state is { HasRequireAuthorizationAttribute: true, HasAllowAnonymousAttribute: false }) - state.AllowAnonymous = false; - - var metadata = new RequestHandlerMetadata(name, displayName, state.Summary, description, state.Tags, ToEquatableOrNull(state.Accepts), - ToEquatableOrNull(state.Produces), ToEquatableOrNull(state.ProducesProblem), ToEquatableOrNull(state.ProducesValidationProblem), - state.ExcludeFromDescription ?? false - ); - - var withRequestTimeout = state.WithRequestTimeout ?? false; - var requestTimeoutPolicyName = withRequestTimeout ? state.RequestTimeoutPolicyName : null; - - return new EndpointConfiguration(metadata, state.RequireAuthorization ?? false, state.AuthorizationPolicies, state.DisableAntiforgery ?? false, - state.AllowAnonymous ?? false, state.RequireCors ?? false, state.CorsPolicyName, state.RequiredHosts, state.RequireRateLimiting ?? false, - state.RateLimitingPolicyName, ToEquatableOrNull(state.EndpointFilters), state.ShortCircuit ?? false, state.DisableValidation ?? false, - state.DisableRequestTimeout ?? false, withRequestTimeout, requestTimeoutPolicyName, state.Order, state.EndpointGroupName - ); - } - - private static void GetAdditionalRequestHandlerAttributeValues(ImmutableArray attributes, ref EndpointAttributeState state) - { - ref var tags = ref state.Tags; - ref var requireAuthorization = ref state.RequireAuthorization; - ref var authorizationPolicies = ref state.AuthorizationPolicies; - ref var disableAntiforgery = ref state.DisableAntiforgery; - ref var allowAnonymous = ref state.AllowAnonymous; - ref var excludeFromDescription = ref state.ExcludeFromDescription; - ref var accepts = ref state.Accepts; - ref var produces = ref state.Produces; - ref var producesProblem = ref state.ProducesProblem; - ref var producesValidationProblem = ref state.ProducesValidationProblem; - ref var requireCors = ref state.RequireCors; - ref var corsPolicyName = ref state.CorsPolicyName; - ref var requiredHosts = ref state.RequiredHosts; - ref var requireRateLimiting = ref state.RequireRateLimiting; - ref var rateLimitingPolicyName = ref state.RateLimitingPolicyName; - ref var endpointFilters = ref state.EndpointFilters; - ref var endpointFilterSet = ref state.EndpointFilterSet; - ref var hasAllowAnonymousAttribute = ref state.HasAllowAnonymousAttribute; - ref var hasRequireAuthorizationAttribute = ref state.HasRequireAuthorizationAttribute; - ref var shortCircuit = ref state.ShortCircuit; - ref var disableValidation = ref state.DisableValidation; - ref var disableRequestTimeout = ref state.DisableRequestTimeout; - ref var withRequestTimeout = ref state.WithRequestTimeout; - ref var requestTimeoutPolicyName = ref state.RequestTimeoutPolicyName; - ref var order = ref state.Order; - ref var endpointGroupName = ref state.EndpointGroupName; - ref var summary = ref state.Summary; - - foreach (var attribute in attributes) - { - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - switch (GetGeneratedAttributeKind(attributeClass)) - { - case GeneratedAttributeKind.ShortCircuit: - shortCircuit = true; - continue; - case GeneratedAttributeKind.DisableValidation: - disableValidation = true; - continue; - case GeneratedAttributeKind.DisableRequestTimeout: - disableRequestTimeout = true; - withRequestTimeout = false; - requestTimeoutPolicyName = null; - continue; - case GeneratedAttributeKind.RequestTimeout: - { - disableRequestTimeout = false; - withRequestTimeout = true; - - string? policyName = null; - if (attribute.ConstructorArguments.Length > 0) - policyName = attribute.ConstructorArguments[0].Value as string; - - policyName ??= GetNamedStringValue(attribute, PolicyNameAttributeNamedParameter); - requestTimeoutPolicyName = NormalizeOptionalString(policyName); - continue; - } - case GeneratedAttributeKind.Order: - if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int orderValue) - order = orderValue; - continue; - case GeneratedAttributeKind.MapGroup: - { - var groupName = GetNamedStringValue(attribute, NameAttributeNamedParameter); - if (!string.IsNullOrEmpty(groupName)) - endpointGroupName = groupName; - continue; - } - case GeneratedAttributeKind.Summary: - if (attribute.ConstructorArguments.Length > 0) - { - var summaryValue = NormalizeOptionalString(attribute.ConstructorArguments[0].Value as string); - if (!string.IsNullOrEmpty(summaryValue)) - summary = summaryValue; - } - continue; - case GeneratedAttributeKind.Accepts: - TryAddAcceptsMetadata(attribute, attributeClass, ref accepts); - continue; - case GeneratedAttributeKind.ProducesResponse: - TryAddProducesMetadata(attribute, attributeClass, ref produces); - continue; - case GeneratedAttributeKind.RequireAuthorization: - requireAuthorization = true; - hasRequireAuthorizationAttribute = true; - if (attribute.ConstructorArguments.Length == 1) - { - var arg = attribute.ConstructorArguments[0]; - MergeInto(ref authorizationPolicies, arg.Values); - } - - continue; - case GeneratedAttributeKind.RequireCors: - requireCors = true; - corsPolicyName = attribute.ConstructorArguments.Length > 0 - ? NormalizeOptionalString(attribute.ConstructorArguments[0].Value as string) - : null; - continue; - case GeneratedAttributeKind.RequireHost: - if (attribute.ConstructorArguments.Length == 1) - { - var arg = attribute.ConstructorArguments[0]; - if (arg is { Kind: TypedConstantKind.Array, Values.Length: > 0 }) - MergeInto(ref requiredHosts, arg.Values); - else if (arg.Value is string singleHost && !string.IsNullOrWhiteSpace(singleHost)) - MergeInto(ref requiredHosts, [singleHost.Trim()]); - } - - continue; - case GeneratedAttributeKind.RequireRateLimiting: - { - var policyName = attribute.ConstructorArguments.Length > 0 - ? NormalizeOptionalString(attribute.ConstructorArguments[0].Value as string) - : null; - - if (!string.IsNullOrEmpty(policyName)) - { - requireRateLimiting = true; - rateLimitingPolicyName = policyName; - } - - continue; - } - case GeneratedAttributeKind.EndpointFilter: - TryAddEndpointFilter(attribute, attributeClass, ref endpointFilters, ref endpointFilterSet); - continue; - case GeneratedAttributeKind.DisableAntiforgery: - disableAntiforgery = true; - continue; - case GeneratedAttributeKind.ProducesProblem: - { - var statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesProblemStatusCode - ? producesProblemStatusCode - : 500; - var contentType = attribute.ConstructorArguments.Length > 1 - ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) - : null; - var additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; - - var producesProblemList = producesProblem ??= []; - producesProblemList.Add(new ProducesProblemMetadata(statusCode, contentType, additionalContentTypes)); - continue; - } - case GeneratedAttributeKind.ProducesValidationProblem: - { - var statusCode = - attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesValidationProblemStatusCode - ? producesValidationProblemStatusCode - : 400; - var contentType = attribute.ConstructorArguments.Length > 1 - ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) - : null; - var additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; - - var producesValidationProblemList = producesValidationProblem ??= []; - producesValidationProblemList.Add(new ProducesValidationProblemMetadata(statusCode, contentType, additionalContentTypes)); - continue; - } - } - - if (IsAttribute(attributeClass, AllowAnonymousAttributeName, AspNetCoreAuthorizationNamespaceParts)) - { - allowAnonymous = true; - hasAllowAnonymousAttribute = true; - continue; - } - - if (IsAttribute(attributeClass, "TagsAttribute", AspNetCoreHttpNamespaceParts)) - { - if (attribute.ConstructorArguments.Length > 0) - { - var arg = attribute.ConstructorArguments[0]; - MergeInto(ref tags, arg.Values); - } - - continue; - } - - if (IsAttribute(attributeClass, "ExcludeFromDescriptionAttribute", AspNetCoreRoutingNamespaceParts)) - excludeFromDescription = true; - } - } - - private static void MergeInto(ref EquatableImmutableArray? target, IEnumerable values) - { - var merged = MergeUnion(target, values); - target = merged.Count > 0 ? merged : null; - } - - private static void MergeInto(ref EquatableImmutableArray? target, ImmutableArray values) - { - if (values.IsDefaultOrEmpty) - return; - - List? normalized = null; - foreach (var value in values) - { - if (value.Value is not string stringValue) - continue; - - var trimmed = NormalizeOptionalString(stringValue); - if (trimmed is not { Length: > 0 }) - continue; - - normalized ??= new List(values.Length); - normalized.Add(trimmed); - } - - if (normalized is { Count: > 0 }) - MergeInto(ref target, normalized); - } - - private static EquatableImmutableArray? ToEquatableOrNull(List? values) - { - return values is { Count: > 0 } ? values.ToEquatableImmutableArray() : null; - } - - private static string NormalizeRequiredContentType(string? contentType, string defaultValue) - { - return string.IsNullOrWhiteSpace(contentType) ? defaultValue : contentType!.Trim(); - } - - private static string? NormalizeOptionalContentType(string? contentType) - { - return string.IsNullOrWhiteSpace(contentType) ? null : contentType!.Trim(); - } - - private static string? NormalizeOptionalString(string? value) - { - return string.IsNullOrWhiteSpace(value) ? null : value!.Trim(); - } - - [SuppressMessage("Major Code Smell", "S3398:Move this method into a class of its own", Justification = "Shared helper for multiple caching paths.")] - private static string? GetMapGroupPattern(INamedTypeSymbol classSymbol) + private static string RemoveAsyncSuffix(string methodName) { - foreach (var attribute in classSymbol.GetAttributes()) - { - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - if (GetGeneratedAttributeKind(attributeClass) != GeneratedAttributeKind.MapGroup) - continue; - - if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is string pattern) - return pattern.Trim(); - } + if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) + return methodName[..^AsyncSuffix.Length]; - return null; + return methodName; } - [SuppressMessage("Major Code Smell", "S3398:Move this method into a class of its own", Justification = "Shared helper for multiple caching paths.")] - private static string GetMapGroupIdentifier(string className) + private static RequestHandlerClass? GetRequestHandlerClass(IMethodSymbol methodSymbol, CancellationToken cancellationToken) { - if (className.StartsWith(GlobalPrefix, StringComparison.Ordinal)) - className = className.Substring(GlobalPrefix.Length); - - var builder = StringBuilderPool.Get(className.Length + 8); - builder.Append('_'); - - foreach (var character in className) - builder.Append(char.IsLetterOrDigit(character) ? character : '_'); - - builder.Append("_Group"); - return StringBuilderPool.ToStringAndReturn(builder); - } + cancellationToken.ThrowIfCancellationRequested(); - private static EquatableImmutableArray? GetStringArrayValues(TypedConstant typedConstant) - { - if (typedConstant.Kind != TypedConstantKind.Array || typedConstant.Values.IsDefaultOrEmpty) + var classSymbol = methodSymbol.ContainingType; + if (classSymbol.TypeKind != TypeKind.Class) return null; - var builder = ImmutableArray.CreateBuilder(typedConstant.Values.Length); - foreach (var value in typedConstant.Values) - { - if (value.Value is string s && !string.IsNullOrWhiteSpace(s)) - builder.Add(s.Trim()); - } - - return builder.Count > 0 ? builder.ToEquatableImmutable() : null; - } - - private static GeneratedAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) - { - var definition = attributeClass.OriginalDefinition; - var cacheEntry = - GeneratedAttributeKindCache.GetValue(definition, static def => new GeneratedAttributeKindCacheEntry(GetGeneratedAttributeKindCore(def))); - - return cacheEntry.Kind; - } - - private static GeneratedAttributeKind GetGeneratedAttributeKindCore(INamedTypeSymbol definition) - { - if (!IsInNamespace(definition.ContainingNamespace, AttributesNamespaceParts)) - return GeneratedAttributeKind.None; - - return definition.Name switch - { - ShortCircuitAttributeName => GeneratedAttributeKind.ShortCircuit, - DisableValidationAttributeName => GeneratedAttributeKind.DisableValidation, - DisableRequestTimeoutAttributeName => GeneratedAttributeKind.DisableRequestTimeout, - RequestTimeoutAttributeName => GeneratedAttributeKind.RequestTimeout, - OrderAttributeName => GeneratedAttributeKind.Order, - MapGroupAttributeName => GeneratedAttributeKind.MapGroup, - SummaryAttributeName => GeneratedAttributeKind.Summary, - AcceptsAttributeName => GeneratedAttributeKind.Accepts, - ProducesResponseAttributeName => GeneratedAttributeKind.ProducesResponse, - RequireAuthorizationAttributeName => GeneratedAttributeKind.RequireAuthorization, - RequireCorsAttributeName => GeneratedAttributeKind.RequireCors, - RequireHostAttributeName => GeneratedAttributeKind.RequireHost, - RequireRateLimitingAttributeName => GeneratedAttributeKind.RequireRateLimiting, - EndpointFilterAttributeName => GeneratedAttributeKind.EndpointFilter, - DisableAntiforgeryAttributeName => GeneratedAttributeKind.DisableAntiforgery, - ProducesProblemAttributeName => GeneratedAttributeKind.ProducesProblem, - ProducesValidationProblemAttributeName => GeneratedAttributeKind.ProducesValidationProblem, - _ => GeneratedAttributeKind.None, - }; - } - - private static bool IsAttribute(INamedTypeSymbol attributeClass, string attributeName, string[] namespaceParts) - { - var definition = attributeClass.OriginalDefinition; - return definition.Name == attributeName && IsInNamespace(definition.ContainingNamespace, namespaceParts); - } - - private static BindingSource GetBindingSourceFromAttributeClass(INamedTypeSymbol attributeClass) - { - var definition = attributeClass.OriginalDefinition; - var namespaceSymbol = definition.ContainingNamespace; - - return definition.Name switch - { - "FromRouteAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromRoute, - "FromQueryAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromQuery, - "FromHeaderAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromHeader, - "FromBodyAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromBody, - "FromFormAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromForm, - "FromServicesAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromServices, - "FromKeyedServicesAttribute" when IsInNamespace(namespaceSymbol, ExtensionsDependencyInjectionNamespaceParts) - => BindingSource.FromKeyedServices, - "AsParametersAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreHttpNamespaceParts) => BindingSource.AsParameters, - _ => BindingSource.None, - }; - } - - private static bool IsInNamespace(INamespaceSymbol? namespaceSymbol, string[] namespaceParts) - { - for (var i = namespaceParts.Length - 1; i >= 0; i--) - { - if (namespaceSymbol is null || namespaceSymbol.Name != namespaceParts[i]) - return false; - - namespaceSymbol = namespaceSymbol.ContainingNamespace; - } + var cacheEntry = RequestHandlerClassCache.GetValue(classSymbol, static _ => new RequestHandlerClassCacheEntry()); + var requestHandlerClass = cacheEntry.GetOrCreate(classSymbol, cancellationToken); - return namespaceSymbol is null || namespaceSymbol.IsGlobalNamespace; + return requestHandlerClass; } - private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List? accepts) + private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol methodSymbol, CancellationToken cancellationToken) { - string? requestType; - string contentType; - EquatableImmutableArray? additionalContentTypes; - var isOptional = GetNamedBoolValue(attribute, IsOptionalAttributeNamedParameter); + cancellationToken.ThrowIfCancellationRequested(); - if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) - { - requestType = attributeClass.TypeArguments[0] - .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - contentType = attribute.ConstructorArguments.Length > 0 - ? NormalizeRequiredContentType(attribute.ConstructorArguments[0].Value as string, "application/json") - : "application/json"; - additionalContentTypes = attribute.ConstructorArguments.Length > 1 ? GetStringArrayValues(attribute.ConstructorArguments[1]) : null; - } - else if (GetNamedTypeSymbol(attribute, RequestTypeAttributeNamedParameter) is { } requestTypeSymbol) - { - requestType = requestTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - contentType = attribute.ConstructorArguments.Length > 0 - ? NormalizeRequiredContentType(attribute.ConstructorArguments[0].Value as string, "application/json") - : "application/json"; - additionalContentTypes = attribute.ConstructorArguments.Length > 1 ? GetStringArrayValues(attribute.ConstructorArguments[1]) : null; - } - else - { - return; - } + var name = methodSymbol.Name; + var isStatic = methodSymbol.IsStatic; + var parameters = methodSymbol.GetParameters(cancellationToken); + var configuration = EndpointConfigurationFactory.Create(methodSymbol); + var requestHandlerMethod = new RequestHandlerMethod(name, isStatic, parameters, configuration); - var acceptsList = accepts ??= []; - acceptsList.Add(new AcceptsMetadata(requestType, contentType, additionalContentTypes, isOptional)); + return requestHandlerMethod; } - private static void TryAddProducesMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List? produces) + private static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) { - string? responseType; - int statusCode; - string? contentType; - EquatableImmutableArray? additionalContentTypes; + context.CancellationToken.ThrowIfCancellationRequested(); - if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) - { - responseType = attributeClass.TypeArguments[0] - .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode - ? producesStatusCode - : 200; - contentType = attribute.ConstructorArguments.Length > 1 ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) : null; - additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; - } - else if (GetNamedTypeSymbol(attribute, ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol) - { - responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - statusCode = attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is int producesStatusCode - ? producesStatusCode - : 200; - contentType = attribute.ConstructorArguments.Length > 1 ? NormalizeOptionalContentType(attribute.ConstructorArguments[1].Value as string) : null; - additionalContentTypes = attribute.ConstructorArguments.Length > 2 ? GetStringArrayValues(attribute.ConstructorArguments[2]) : null; - } - else - { - return; - } + var normalized = NormalizeRequestHandlers(requestHandlers); - var producesList = produces ??= []; - producesList.Add(new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes)); + AddEndpointHandlersGenerator.GenerateSource(context, normalized); + UseEndpointHandlersGenerator.GenerateSource(context, normalized); } - private static void TryAddEndpointFilter( - AttributeData attribute, - INamedTypeSymbol attributeClass, - ref List? endpointFilters, - ref HashSet? endpointFilterSet) + private static EquatableImmutableArray NormalizeRequestHandlers(EquatableImmutableArray requestHandlers) { - if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) - { - TryAddEndpointFilterType(attributeClass.TypeArguments[0], ref endpointFilters, ref endpointFilterSet); - return; - } + if (requestHandlers.Count <= 1) + return requestHandlers; - if (attribute.ConstructorArguments.Length == 0) - return; + requestHandlers.SortInPlace(RequestHandlerComparer.Instance); + ResolveEndpointNameCollisions(requestHandlers); - if (attribute.ConstructorArguments[0].Value is ITypeSymbol filterTypeSymbol) - TryAddEndpointFilterType(filterTypeSymbol, ref endpointFilters, ref endpointFilterSet); + return requestHandlers; } - private static void TryAddEndpointFilterType( - ITypeSymbol? typeSymbol, - ref List? endpointFilters, - ref HashSet? endpointFilterSet) + private static void ResolveEndpointNameCollisions(EquatableImmutableArray requestHandlers) { - if (typeSymbol is null or ITypeParameterSymbol or IErrorTypeSymbol) - return; - - var displayString = typeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - if (string.IsNullOrWhiteSpace(displayString)) - return; - - endpointFilterSet ??= new HashSet(StringComparer.Ordinal); - if (!endpointFilterSet.Add(displayString)) + var raw = requestHandlers.AsArray(); + if (raw is null) return; - endpointFilters ??= []; - endpointFilters.Add(displayString); - } - - private static ITypeSymbol? GetNamedTypeSymbol(AttributeData attribute, string namedParameter) - { - foreach (var namedArg in attribute.NamedArguments) - { - if (namedArg.Key == namedParameter && namedArg.Value.Value is ITypeSymbol typeSymbol) - return typeSymbol; - } - - return null; - } - - private static bool GetNamedBoolValue(AttributeData attribute, string namedParameter, bool defaultValue = false) - { - foreach (var namedArg in attribute.NamedArguments) - { - if (namedArg.Key == namedParameter && namedArg.Value.Value is bool boolValue) - return boolValue; - } - - return defaultValue; - } - - private static string? GetNamedStringValue(AttributeData attribute, string namedParameter) - { - foreach (var namedArg in attribute.NamedArguments) - { - if (namedArg.Key == namedParameter && namedArg.Value.Value is string stringValue) - return NormalizeOptionalString(stringValue); - } - - return null; - } - - private static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, IEnumerable values) - { - List? list = null; - HashSet? seen = null; - - if (existing is { Count: > 0 }) - { - var count = existing.Value.Count; - list = new List(count + 4); - list.AddRange(existing.Value); - seen = new HashSet(existing.Value, StringComparer.OrdinalIgnoreCase); - } + var span = raw.AsSpan(); + var seen = new Dictionary(span.Length, StringComparer.Ordinal); - foreach (var value in values) + for (var index = 0; index < span.Length; index++) { - var normalized = NormalizeOptionalString(value); - if (normalized is not { Length: > 0 }) + ref var handler = ref span[index]; + var name = handler.Name; + if (string.IsNullOrEmpty(name)) continue; + var nonEmptyName = name!; - seen ??= new HashSet(StringComparer.OrdinalIgnoreCase); - if (!seen.Add(normalized)) + if (!seen.TryGetValue(nonEmptyName, out var state)) + { + seen.Add(nonEmptyName, index); continue; + } - list ??= []; + var firstIndex = state >= 0 ? state : ~state; + if (state >= 0) + { + ref var firstHandler = ref span[firstIndex]; + firstHandler.SetFullyQualifiedName(); + seen[nonEmptyName] = ~firstIndex; + } - list.Add(normalized); + handler.SetFullyQualifiedName(); } - - return list?.ToEquatableImmutableArray() ?? EquatableImmutableArray.Empty; - } - - private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol methodSymbol, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - var name = methodSymbol.Name; - var isStatic = methodSymbol.IsStatic; - var isAwaitable = methodSymbol.ReturnType.IsTask(out _) || methodSymbol.ReturnType.IsValueTask(out _); - var parameters = GetRequestHandlerParameters(methodSymbol, cancellationToken); - - var requestHandlerMethod = new RequestHandlerMethod(name, isStatic, isAwaitable, parameters); - - return requestHandlerMethod; - } - - private static RequestHandlerClass? GetRequestHandlerClass(IMethodSymbol methodSymbol, Compilation compilation, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - var classSymbol = methodSymbol.ContainingType; - if (classSymbol.TypeKind != TypeKind.Class) - return null; - - var typeCache = GetCompilationTypeCache(compilation); - var cacheEntry = RequestHandlerClassCache.GetValue(classSymbol, static _ => new RequestHandlerClassCacheEntry()); - var requestHandlerClass = cacheEntry.GetOrCreate(classSymbol, typeCache, cancellationToken); - return requestHandlerClass; - } - - private static CompilationTypeCache GetCompilationTypeCache(Compilation compilation) - { - return CompilationTypeCaches.GetValue(compilation, static c => new CompilationTypeCache(c)); - } - - [SuppressMessage("Major Code Smell", "S3398:Move this method into a class of its own", Justification = "Shared helper reused by caching infrastructure.")] - private static ConfigureMethodDetails GetConfigureMethodDetails( - INamedTypeSymbol classSymbol, - INamedTypeSymbol? endpointConventionBuilderSymbol, - INamedTypeSymbol? serviceProviderSymbol, - CancellationToken cancellationToken - ) - { - cancellationToken.ThrowIfCancellationRequested(); - - var hasConfigureMethod = false; - var acceptsServiceProvider = false; - foreach (var member in classSymbol.GetMembers(ConfigureMethodName)) - { - cancellationToken.ThrowIfCancellationRequested(); - - if (member is not IMethodSymbol methodSymbol) - continue; - - if (IsConfigureMethod(methodSymbol, endpointConventionBuilderSymbol, serviceProviderSymbol, out var methodAcceptsServiceProvider)) - { - hasConfigureMethod = true; - if (methodAcceptsServiceProvider) - { - acceptsServiceProvider = true; - break; - } - } - } - - return new ConfigureMethodDetails(hasConfigureMethod, acceptsServiceProvider); - } - - private static bool IsConfigureMethod( - IMethodSymbol methodSymbol, - INamedTypeSymbol? endpointConventionBuilderSymbol, - INamedTypeSymbol? serviceProviderSymbol, - out bool acceptsServiceProvider - ) - { - acceptsServiceProvider = false; - - if (!methodSymbol.IsStatic) - return false; - - if (methodSymbol.TypeParameters.Length != 1) - return false; - - if (methodSymbol.Parameters.Length is < 1 or > 2) - return false; - - var builderTypeParameter = methodSymbol.TypeParameters[0]; - var builderParameter = methodSymbol.Parameters[0]; - - if (!SymbolEqualityComparer.Default.Equals(builderParameter.Type, builderTypeParameter)) - return false; - - if (methodSymbol.Parameters.Length == 2) - { - var serviceProviderParameter = methodSymbol.Parameters[1]; - if (!IsServiceProviderParameter(serviceProviderParameter.Type, serviceProviderSymbol)) - return false; - - acceptsServiceProvider = true; - } - - if (!methodSymbol.ReturnsVoid) - return false; - - if (!HasEndpointConventionBuilderConstraint(builderTypeParameter, methodSymbol, endpointConventionBuilderSymbol)) - return false; - - return true; - } - - private static bool IsServiceProviderParameter(ITypeSymbol typeSymbol, INamedTypeSymbol? serviceProviderSymbol) - { - if (serviceProviderSymbol is not null) - return SymbolEqualityComparer.Default.Equals(typeSymbol, serviceProviderSymbol); - - return MatchesServiceProvider(typeSymbol); - } - - private static bool HasEndpointConventionBuilderConstraint( - ITypeParameterSymbol builderTypeParameter, - IMethodSymbol methodSymbol, - INamedTypeSymbol? endpointConventionBuilderSymbol - ) - { - var symbolMatches = builderTypeParameter.ConstraintTypes.Any(constraint => - endpointConventionBuilderSymbol is not null - ? SymbolEqualityComparer.Default.Equals(constraint, endpointConventionBuilderSymbol) - : MatchesEndpointConventionBuilder(constraint) - ); - - if (symbolMatches) - return true; - - return methodSymbol.DeclaringSyntaxReferences - .Select(reference => reference.GetSyntax()) - .OfType() - .SelectMany(methodSyntax => methodSyntax.ConstraintClauses) - .Where(clause => string.Equals(clause.Name.Identifier.ValueText, builderTypeParameter.Name, StringComparison.Ordinal)) - .SelectMany(clause => clause.Constraints.OfType()) - .Any(constraint => IsEndpointConventionBuilderIdentifier(constraint.Type)); - } - - private static bool IsEndpointConventionBuilderIdentifier(TypeSyntax typeSyntax) - { - return typeSyntax switch - { - QualifiedNameSyntax qualified => IsEndpointConventionBuilderIdentifier(qualified.Right), - AliasQualifiedNameSyntax alias => IsEndpointConventionBuilderIdentifier(alias.Name), - SimpleNameSyntax simple => string.Equals(simple.Identifier.ValueText, "IEndpointConventionBuilder", StringComparison.Ordinal), - _ => false, - }; - } - - private static bool MatchesEndpointConventionBuilder(ITypeSymbol typeSymbol) - { - if (typeSymbol is not INamedTypeSymbol namedType) - return false; - - if (!string.Equals(namedType.Name, "IEndpointConventionBuilder", StringComparison.Ordinal)) - return false; - - var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? string.Empty; - return string.Equals(containingNamespace, "Microsoft.AspNetCore.Builder", StringComparison.Ordinal); - } - - private static bool MatchesServiceProvider(ITypeSymbol typeSymbol) - { - if (typeSymbol is not INamedTypeSymbol namedType) - return false; - - if (!string.Equals(namedType.Name, "IServiceProvider", StringComparison.Ordinal)) - return false; - - var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? string.Empty; - return string.Equals(containingNamespace, "System", StringComparison.Ordinal); - } - - private static EquatableImmutableArray GetRequestHandlerParameters(IMethodSymbol methodSymbol, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - var methodParameters = ImmutableArray.CreateBuilder(methodSymbol.Parameters.Length); - foreach (var parameter in methodSymbol.Parameters) - { - cancellationToken.ThrowIfCancellationRequested(); - - var source = BindingSource.None; - TypedConstant? typedKey = null; - string? bindingName = null; - - foreach (var attribute in parameter.GetAttributes()) - { - var attributeClass = attribute.AttributeClass; - if (attributeClass is null) - continue; - - var attributeSource = GetBindingSourceFromAttributeClass(attributeClass); - if (attributeSource == BindingSource.None) - continue; - - source = attributeSource; - switch (attributeSource) - { - case BindingSource.FromRoute: - case BindingSource.FromQuery: - case BindingSource.FromHeader: - case BindingSource.FromForm: - bindingName = GetBindingAttributeName(attribute) ?? bindingName; - break; - case BindingSource.FromKeyedServices: - typedKey = attribute.ConstructorArguments.Length > 0 ? attribute.ConstructorArguments[0] : null; - break; - } - } - - var parameterName = parameter.Name; - var parameterType = parameter.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var key = typedKey.HasValue ? ConstLiteral(typedKey.Value) : null; - var bindingPrefix = GetBindingSourceAttribute(source, key, bindingName); - methodParameters.Add(new Parameter(parameterName, parameterType, bindingPrefix)); - } - - return methodParameters.ToEquatableImmutable(); - } - - private static string? GetBindingAttributeName(AttributeData attribute) - { - foreach (var namedArg in attribute.NamedArguments) - { - if (string.Equals(namedArg.Key, NameAttributeNamedParameter, StringComparison.Ordinal) && namedArg.Value.Value is string namedValue) - { - var normalized = NormalizeBindingName(namedValue); - if (normalized is not null) - return normalized; - } - } - - if (attribute.ConstructorArguments.Length > 0 && attribute.ConstructorArguments[0].Value is string constructorName) - return NormalizeBindingName(constructorName); - - return null; - } - - private static string? NormalizeBindingName(string? value) - { - if (string.IsNullOrWhiteSpace(value)) - return null; - - var trimmed = value!.Trim(); - return trimmed.Length > 0 ? trimmed : null; - } - - private static void GenerateSource(SourceProductionContext context, ImmutableArray requestHandlers) - { - context.CancellationToken.ThrowIfCancellationRequested(); - - var sorted = SortRequestHandlers(requestHandlers); - sorted = EnsureUniqueEndpointNames(sorted); - - GenerateAddEndpointHandlersClass(context, sorted); - GenerateUseEndpointHandlersClass(context, sorted); - } - - private static ImmutableArray SortRequestHandlers(ImmutableArray requestHandlers) - { - if (requestHandlers.Length <= 1) - return requestHandlers; - - var array = requestHandlers.ToArray(); - Array.Sort(array, RequestHandlerComparer.Instance); - return [..array]; - } - - private static ImmutableArray EnsureUniqueEndpointNames(ImmutableArray requestHandlers) - { - var collidingHandlers = GetRequestHandlersWithNameCollisions(requestHandlers); - if (collidingHandlers.IsEmpty) - return requestHandlers; - - var builder = requestHandlers.ToBuilder(); - foreach (var index in collidingHandlers) - { - var handler = builder[index]; - var configuration = handler.Configuration; - var metadata = configuration.Metadata with - { - Name = GetFullyQualifiedMethodDisplayName(handler), - }; - configuration = configuration with - { - Metadata = metadata, - }; - builder[index] = handler with - { - Configuration = configuration, - }; - } - - return builder.MoveToImmutable(); - } - - private static ImmutableArray GetRequestHandlersWithNameCollisions(ImmutableArray requestHandlers) - { - if (requestHandlers.IsDefaultOrEmpty) - return ImmutableArray.Empty; - - var handlerCount = requestHandlers.Length; - var nameToFirstIndex = new Dictionary<(string Name, string Method), int>(handlerCount); - var collisionFlags = ArrayPool.Shared.Rent(handlerCount); - Array.Clear(collisionFlags, 0, handlerCount); - List? collidingIndices = null; - - try - { - for (var index = 0; index < handlerCount; index++) - { - var handler = requestHandlers[index]; - var name = handler.Configuration.Metadata.Name; - if (string.IsNullOrEmpty(name)) - continue; - - var key = (name!, handler.Method.Name); - - if (nameToFirstIndex.TryGetValue(key, out var firstIndex)) - { - MarkCollision(firstIndex); - MarkCollision(index); - } - else - { - nameToFirstIndex.Add(key, index); - } - } - - if (collidingIndices is null || collidingIndices.Count == 0) - return ImmutableArray.Empty; - - collidingIndices.Sort(); - var builder = ImmutableArray.CreateBuilder(collidingIndices.Count); - builder.AddRange(collidingIndices); - return builder.MoveToImmutable(); - } - finally - { - ArrayPool.Shared.Return(collisionFlags); - } - - void MarkCollision(int handlerIndex) - { - if (collisionFlags[handlerIndex]) - return; - - collisionFlags[handlerIndex] = true; - collidingIndices ??= new List(); - collidingIndices.Add(handlerIndex); - } - } - - private static string GetFullyQualifiedMethodDisplayName(RequestHandler requestHandler) - { - var className = requestHandler.Class.Name; - if (className.StartsWith(GlobalPrefix, StringComparison.Ordinal)) - className = className.Substring(GlobalPrefix.Length); - - if (className.IndexOf('+') >= 0) - className = className.Replace('+', '.'); - - return string.Concat(className, ".", requestHandler.Method.Name); - } - - private static void GenerateAddEndpointHandlersClass(SourceProductionContext context, ImmutableArray requestHandlers) - { - context.CancellationToken.ThrowIfCancellationRequested(); - - var nonStaticClassNames = GetDistinctNonStaticClassNames(requestHandlers); - var source = GetAddEndpointHandlersStringBuilder(nonStaticClassNames); - source.AppendLine(FileHeader); - - source.AppendLine(); - - source.AppendLine("using Microsoft.Extensions.DependencyInjection;"); - source.AppendLine("using Microsoft.Extensions.DependencyInjection.Extensions;"); - source.AppendLine(); - - source.Append("namespace "); - source.Append(RoutingNamespace); - source.AppendLine(";"); - - source.AppendLine(); - - source.Append("internal static class "); - source.Append(AddEndpointHandlersClassName); - source.AppendLine(); - - source.AppendLine("{"); - - source.Append(" internal static void "); - source.Append(AddEndpointHandlersMethodName); - source.AppendLine("(this IServiceCollection services)"); - - source.AppendLine(" {"); - - foreach (var className in nonStaticClassNames) - { - source.Append(" services.TryAddScoped<"); - source.Append(className); - source.Append(">();"); - source.AppendLine(); - } - - source.AppendLine(""" - } - } - """ - ); - - var sourceText = StringBuilderPool.ToStringAndReturn(source); - context.AddSource(AddEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); - } - - [SuppressMessage("Major Code Smell", "S3267:Loops should be simplified by calling the \"Select\" LINQ method", - Justification = "Manual loops avoid repeated allocations in the source generator." - )] - private static List GetDistinctNonStaticClassNames(ImmutableArray requestHandlers) - { - var classNames = new List(); - if (requestHandlers.IsDefaultOrEmpty) - return classNames; - - var seen = new HashSet(StringComparer.Ordinal); - foreach (var requestHandler in requestHandlers) - { - if (requestHandler.Class.IsStatic) - continue; - - var className = requestHandler.Class.Name; - if (seen.Add(className)) - classNames.Add(className); - } - - return classNames; - } - - private static StringBuilder GetAddEndpointHandlersStringBuilder(List nonStaticClassNames) - { - var estimate = 512L; - foreach (var className in nonStaticClassNames) - estimate += 36 + className.Length; - - estimate += Math.Max(256, nonStaticClassNames.Count * 12); - estimate = (long)(estimate * 1.10); - - if (estimate < 512) - estimate = 512; - else if (estimate > int.MaxValue) - estimate = int.MaxValue; - - return StringBuilderPool.Get((int)estimate); - } - - private static void GenerateUseEndpointHandlersClass(SourceProductionContext context, ImmutableArray requestHandlers) - { - context.CancellationToken.ThrowIfCancellationRequested(); - - var source = GetUseEndpointHandlersStringBuilder(requestHandlers); - source.AppendLine(FileHeader); - - source.AppendLine(); - - source.AppendLine("using Microsoft.AspNetCore.Builder;"); - source.AppendLine("using Microsoft.AspNetCore.Http;"); - source.AppendLine("using Microsoft.AspNetCore.Mvc;"); - source.AppendLine("using Microsoft.AspNetCore.Routing;"); - if (HasRateLimitedHandlers(requestHandlers)) - source.AppendLine("using Microsoft.AspNetCore.RateLimiting;"); - source.AppendLine("using Microsoft.Extensions.DependencyInjection;"); - source.AppendLine(); - - source.Append("namespace "); - source.Append(RoutingNamespace); - source.AppendLine(";"); - - source.AppendLine(); - - source.Append("internal static class "); - source.Append(UseEndpointHandlersClassName); - source.AppendLine(); - - source.AppendLine("{"); - - source.Append(" internal static IEndpointRouteBuilder "); - source.Append(UseEndpointHandlersMethodName); - source.AppendLine("(this IEndpointRouteBuilder builder)"); - - source.AppendLine(" {"); - - var groupedClasses = GetClassesWithMapGroups(requestHandlers); - - foreach (var groupedClass in groupedClasses) - { - source.Append(" var "); - source.Append(groupedClass.MapGroupBuilderIdentifier); - source.Append(" = builder.MapGroup("); - source.Append(StringLiteral(groupedClass.MapGroupPattern!)); - source.Append(')'); - AppendEndpointConfiguration(source, " ", groupedClass.Configuration, false); - source.AppendLine(";"); - } - - if (groupedClasses.Count > 0) - source.AppendLine(); - - for (var index = 0; index < requestHandlers.Length; index++) - { - if (index > 0) - source.AppendLine(); - - var requestHandler = requestHandlers[index]; - GenerateMapRequestHandler(source, requestHandler); - } - - source.AppendLine(""" - - return builder; - } - } - """ - ); - - var sourceText = StringBuilderPool.ToStringAndReturn(source); - context.AddSource(UseEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); - } - - private static bool HasRateLimitedHandlers(ImmutableArray requestHandlers) - { - foreach (var handler in requestHandlers) - { - if (handler.Configuration.RequireRateLimiting) - return true; - } - - return false; - } - - [SuppressMessage("Major Code Smell", "S3267:Loops should be simplified by calling the \"Select\" LINQ method", - Justification = "Manual loops avoid repeated allocations in the source generator." - )] - private static List GetClassesWithMapGroups(ImmutableArray requestHandlers) - { - var groupedClasses = new List(); - if (requestHandlers.IsDefaultOrEmpty) - return groupedClasses; - - var seen = new HashSet(StringComparer.Ordinal); - foreach (var handler in requestHandlers) - { - var handlerClass = handler.Class; - if (handlerClass.MapGroupPattern is null) - continue; - - if (seen.Add(handlerClass.Name)) - groupedClasses.Add(handlerClass); - } - - return groupedClasses; - } - - private static void GenerateMapRequestHandler(StringBuilder source, RequestHandler requestHandler) - { - var wrapWithConfigure = requestHandler.Class.HasConfigureMethod; - var configureAcceptsServiceProvider = requestHandler.Class.ConfigureMethodAcceptsServiceProvider; - var indent = wrapWithConfigure ? " " : " "; - var continuationIndent = indent + " "; - var routeBuilderIdentifier = requestHandler.Class.MapGroupBuilderIdentifier ?? "builder"; - - if (wrapWithConfigure) - { - source.Append(" "); - source.Append(requestHandler.Class.Name); - source.Append('.'); - source.Append(ConfigureMethodName); - source.AppendLine("("); - } - - var isFallback = string.Equals(requestHandler.HttpMethod, FallbackHttpMethod, StringComparison.Ordinal); - var mapMethodSuffix = isFallback ? null : GetMapMethodSuffix(requestHandler.HttpMethod); - - source.Append(indent); - if (isFallback) - { - source.Append(routeBuilderIdentifier); - source.Append(".MapFallback("); - if (!string.IsNullOrEmpty(requestHandler.Pattern)) - { - source.Append(StringLiteral(requestHandler.Pattern)); - source.Append(", "); - } - } - else - { - source.Append(routeBuilderIdentifier); - source.Append(".Map"); - source.Append(mapMethodSuffix ?? "Methods"); - source.Append('('); - source.Append(StringLiteral(requestHandler.Pattern)); - source.Append(", "); - if (mapMethodSuffix is null) - { - source.Append("new[] { \""); - source.Append(requestHandler.HttpMethod); - source.Append("\" }, "); - } - } - if (requestHandler.Method.IsStatic) - { - source.Append(requestHandler.Class.Name); - source.Append('.'); - source.Append(requestHandler.Method.Name); - } - else - { - source.Append("static "); - if (requestHandler.Method.IsAwaitable) - source.Append("async "); - source.Append("([FromServices] "); - source.Append(requestHandler.Class.Name); - source.Append(" handler"); - foreach (var parameter in requestHandler.Method.Parameters) - { - source.Append(", "); - source.Append(parameter.BindingPrefix); - source.Append(parameter.Type); - source.Append(' '); - source.Append(parameter.Name); - } - source.Append(") => "); - if (requestHandler.Method.IsAwaitable) - source.Append("await "); - source.Append("handler."); - source.Append(requestHandler.Method.Name); - source.Append('('); - for (var index = 0; index < requestHandler.Method.Parameters.Count; index++) - { - if (index > 0) - source.Append(", "); - var parameter = requestHandler.Method.Parameters[index]; - source.Append(parameter.Name); - } - source.Append(')'); - } - source.Append(')'); - - var configuration = requestHandler.Configuration; - if (requestHandler.Class.MapGroupPattern is null) - configuration = MergeEndpointConfigurations(requestHandler.Class.Configuration, configuration); - - AppendEndpointConfiguration(source, continuationIndent, configuration, true); - - if (wrapWithConfigure && configureAcceptsServiceProvider) - { - source.AppendLine(","); - source.Append(indent); - source.Append("builder.ServiceProvider"); - } - - if (wrapWithConfigure) - { - source.AppendLine(); - source.Append(" );"); - source.AppendLine(); - } - else - { - source.AppendLine(";"); - } - } - - private static void AppendEndpointConfiguration(StringBuilder source, string indent, EndpointConfiguration configuration, bool includeNameAndDisplayName) - { - var metadata = configuration.Metadata; - - if (includeNameAndDisplayName && !string.IsNullOrEmpty(metadata.Name)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithName("); - source.Append(StringLiteral(metadata.Name)); - source.Append(')'); - } - - if (includeNameAndDisplayName && !string.IsNullOrEmpty(metadata.DisplayName)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithDisplayName("); - source.Append(StringLiteral(metadata.DisplayName)); - source.Append(')'); - } - - if (!string.IsNullOrEmpty(metadata.Summary)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithSummary("); - source.Append(StringLiteral(metadata.Summary)); - source.Append(')'); - } - - if (!string.IsNullOrEmpty(metadata.Description)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithDescription("); - source.Append(StringLiteral(metadata.Description)); - source.Append(')'); - } - - if (!string.IsNullOrEmpty(configuration.EndpointGroupName)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithGroupName("); - source.Append(StringLiteral(configuration.EndpointGroupName)); - source.Append(')'); - } - - if (configuration.Order is { } orderValue) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithOrder("); - source.Append(orderValue); - source.Append(')'); - } - - if (metadata.ExcludeFromDescription) - { - source.AppendLine(); - source.Append(indent); - source.Append(".ExcludeFromDescription()"); - } - - if (metadata.Tags is { Count: > 0 }) - { - source.AppendLine(); - source.Append(indent); - source.Append(".WithTags("); - AppendCommaSeparatedLiterals(source, metadata.Tags.Value); - source.Append(')'); - } - - if (metadata.Accepts is { Count: > 0 }) - foreach (var accepts in metadata.Accepts.Value) - { - source.AppendLine(); - source.Append(indent); - source.Append(".Accepts<"); - source.Append(accepts.RequestType); - source.Append('>'); - source.Append('('); - if (accepts.IsOptional) - source.Append("isOptional: true, "); - source.Append(StringLiteral(accepts.ContentType)); - AppendAdditionalContentTypes(source, accepts.AdditionalContentTypes); - source.Append(')'); - } - - if (metadata.Produces is { Count: > 0 }) - foreach (var produces in metadata.Produces.Value) - { - source.AppendLine(); - source.Append(indent); - source.Append(".Produces<"); - source.Append(produces.ResponseType); - source.Append('>'); - source.Append('('); - source.Append(produces.StatusCode); - AppendOptionalContentTypes(source, produces.ContentType, produces.AdditionalContentTypes); - source.Append(')'); - } - - if (metadata.ProducesProblem is { Count: > 0 }) - foreach (var producesProblem in metadata.ProducesProblem.Value) - { - source.AppendLine(); - source.Append(indent); - source.Append(".ProducesProblem("); - source.Append(producesProblem.StatusCode); - AppendOptionalContentTypes(source, producesProblem.ContentType, producesProblem.AdditionalContentTypes); - source.Append(')'); - } - - if (metadata.ProducesValidationProblem is { Count: > 0 }) - foreach (var producesValidationProblem in metadata.ProducesValidationProblem.Value) - { - source.AppendLine(); - source.Append(indent); - source.Append(".ProducesValidationProblem("); - source.Append(producesValidationProblem.StatusCode); - AppendOptionalContentTypes(source, producesValidationProblem.ContentType, producesValidationProblem.AdditionalContentTypes); - source.Append(')'); - } - - if (configuration.RequireAuthorization) - { - source.AppendLine(); - if (configuration.AuthorizationPolicies is { Count: > 0 }) - { - source.Append(indent); - source.Append(".RequireAuthorization("); - AppendCommaSeparatedLiterals(source, configuration.AuthorizationPolicies.Value); - source.Append(')'); - } - else - { - source.Append(indent); - source.Append(".RequireAuthorization()"); - } - } - - if (configuration.RequireCors) - { - source.AppendLine(); - source.Append(indent); - if (!string.IsNullOrEmpty(configuration.CorsPolicyName)) - { - source.Append(".RequireCors("); - source.Append(StringLiteral(configuration.CorsPolicyName)); - source.Append(')'); - } - else - { - source.Append(".RequireCors()"); - } - } - - if (configuration.RequiredHosts is { Count: > 0 }) - { - source.AppendLine(); - source.Append(indent); - source.Append(".RequireHost("); - AppendCommaSeparatedLiterals(source, configuration.RequiredHosts.Value); - source.Append(')'); - } - - if (configuration.RequireRateLimiting && !string.IsNullOrEmpty(configuration.RateLimitingPolicyName)) - { - source.AppendLine(); - source.Append(indent); - source.Append(".RequireRateLimiting("); - source.Append(StringLiteral(configuration.RateLimitingPolicyName)); - source.Append(')'); - } - - if (configuration.DisableAntiforgery) - { - source.AppendLine(); - source.Append(indent); - source.Append(".DisableAntiforgery()"); - } - - if (configuration.AllowAnonymous) - { - source.AppendLine(); - source.Append(indent); - source.Append(".AllowAnonymous()"); - } - - if (configuration.ShortCircuit) - { - source.AppendLine(); - source.Append(indent); - source.Append(".ShortCircuit()"); - } - - if (configuration.DisableValidation) - { - source.AppendLine(); - source.Append(indent); - source.Append(".DisableValidation()"); - source.AppendLine(); - } - - if (configuration.DisableRequestTimeout) - { - source.AppendLine(); - source.Append(indent); - source.Append(".DisableRequestTimeout()"); - } - else if (configuration.WithRequestTimeout) - { - source.AppendLine(); - source.Append(indent); - if (!string.IsNullOrEmpty(configuration.RequestTimeoutPolicyName)) - { - source.Append(".WithRequestTimeout("); - source.Append(StringLiteral(configuration.RequestTimeoutPolicyName)); - source.Append(')'); - } - else - { - source.Append(".WithRequestTimeout()"); - } - } - - if (configuration.EndpointFilterTypes is { Count: > 0 }) - foreach (var filterType in configuration.EndpointFilterTypes.Value) - { - source.AppendLine(); - source.Append(indent); - source.Append(".AddEndpointFilter<"); - source.Append(filterType); - source.Append(">()"); - } - } - - private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfiguration classConfiguration, EndpointConfiguration methodConfiguration) - { - var metadata = MergeRequestHandlerMetadata(classConfiguration.Metadata, methodConfiguration.Metadata); - var authorizationPolicies = MergeDistinctStrings(classConfiguration.AuthorizationPolicies, methodConfiguration.AuthorizationPolicies); - var requiredHosts = MergeDistinctStrings(classConfiguration.RequiredHosts, methodConfiguration.RequiredHosts); - var endpointFilterTypes = ConcatEquatable(classConfiguration.EndpointFilterTypes, methodConfiguration.EndpointFilterTypes); - var requireAuthorization = classConfiguration.RequireAuthorization || methodConfiguration.RequireAuthorization; - var disableAntiforgery = classConfiguration.DisableAntiforgery || methodConfiguration.DisableAntiforgery; - var allowAnonymous = classConfiguration.AllowAnonymous || methodConfiguration.AllowAnonymous; - var requireCors = classConfiguration.RequireCors || methodConfiguration.RequireCors; - var corsPolicyName = methodConfiguration.CorsPolicyName ?? classConfiguration.CorsPolicyName; - var requireRateLimiting = classConfiguration.RequireRateLimiting || methodConfiguration.RequireRateLimiting; - var rateLimitingPolicyName = methodConfiguration.RateLimitingPolicyName ?? classConfiguration.RateLimitingPolicyName; - var shortCircuit = classConfiguration.ShortCircuit || methodConfiguration.ShortCircuit; - var disableValidation = classConfiguration.DisableValidation || methodConfiguration.DisableValidation; - var disableRequestTimeout = classConfiguration.DisableRequestTimeout || methodConfiguration.DisableRequestTimeout; - var withRequestTimeout = classConfiguration.WithRequestTimeout || methodConfiguration.WithRequestTimeout; - string? requestTimeoutPolicyName = null; - if (methodConfiguration.WithRequestTimeout) - requestTimeoutPolicyName = methodConfiguration.RequestTimeoutPolicyName; - else if (classConfiguration.WithRequestTimeout) - requestTimeoutPolicyName = classConfiguration.RequestTimeoutPolicyName; - - if (disableRequestTimeout) - { - withRequestTimeout = false; - requestTimeoutPolicyName = null; - } - - var order = methodConfiguration.Order ?? classConfiguration.Order; - var endpointGroupName = methodConfiguration.EndpointGroupName ?? classConfiguration.EndpointGroupName; - - return new EndpointConfiguration(metadata, requireAuthorization, authorizationPolicies, disableAntiforgery, allowAnonymous, requireCors, corsPolicyName, - requiredHosts, requireRateLimiting, rateLimitingPolicyName, endpointFilterTypes, shortCircuit, disableValidation, disableRequestTimeout, - withRequestTimeout, requestTimeoutPolicyName, order, endpointGroupName - ); - } - - private static RequestHandlerMetadata MergeRequestHandlerMetadata(RequestHandlerMetadata classMetadata, RequestHandlerMetadata methodMetadata) - { - return new RequestHandlerMetadata(methodMetadata.Name ?? classMetadata.Name, methodMetadata.DisplayName ?? classMetadata.DisplayName, - methodMetadata.Summary ?? classMetadata.Summary, methodMetadata.Description ?? classMetadata.Description, - MergeDistinctStrings(classMetadata.Tags, methodMetadata.Tags), ConcatEquatable(classMetadata.Accepts, methodMetadata.Accepts), - ConcatEquatable(classMetadata.Produces, methodMetadata.Produces), ConcatEquatable(classMetadata.ProducesProblem, methodMetadata.ProducesProblem), - ConcatEquatable(classMetadata.ProducesValidationProblem, methodMetadata.ProducesValidationProblem), - classMetadata.ExcludeFromDescription || methodMetadata.ExcludeFromDescription - ); - } - - private static EquatableImmutableArray? MergeDistinctStrings(EquatableImmutableArray? first, EquatableImmutableArray? second) - { - if (first is not { Count: > 0 }) - return second; - if (second is not { Count: > 0 }) - return first; - - var merged = MergeUnion(first, second.Value); - return merged.Count > 0 ? merged : null; - } - - private static EquatableImmutableArray? ConcatEquatable(EquatableImmutableArray? first, EquatableImmutableArray? second) - { - if (first is not { Count: > 0 }) - return second; - if (second is not { Count: > 0 }) - return first; - - var builder = ImmutableArray.CreateBuilder(first.Value.Count + second.Value.Count); - builder.AddRange(first.Value); - builder.AddRange(second.Value); - return builder.ToEquatableImmutableArray(); - } - - private static string? GetMapMethodSuffix(string httpMethod) - { - return httpMethod switch - { - "GET" => "Get", - "POST" => "Post", - "PUT" => "Put", - "DELETE" => "Delete", - "PATCH" => "Patch", - _ => null, - }; - } - - private static string GetBindingSourceAttribute(BindingSource source, string? key, string? bindingName) - { - return source switch - { - BindingSource.None => "", - BindingSource.FromRoute => FormatBindingAttribute("FromRoute", bindingName), - BindingSource.FromQuery => FormatBindingAttribute("FromQuery", bindingName), - BindingSource.FromHeader => FormatBindingAttribute("FromHeader", bindingName), - BindingSource.FromBody => FormatBindingAttribute("FromBody", bindingName), - BindingSource.FromForm => FormatBindingAttribute("FromForm", bindingName), - BindingSource.FromServices => "[FromServices] ", - BindingSource.FromKeyedServices => $"[FromKeyedServices({key})] ", - BindingSource.AsParameters => "[AsParameters] ", - _ => throw new NotImplementedException(), - }; - } - - private static string FormatBindingAttribute(string attributeName, string? bindingName) - { - if (bindingName is null) - return $"[{attributeName}] "; - - return $"[{attributeName}(Name = {StringLiteral(bindingName)})] "; - } - - private static StringBuilder GetUseEndpointHandlersStringBuilder(ImmutableArray requestHandlers) - { - const int baseSize = 4096; - const int perHandler = 512; - - var handlerCount = Math.Max(requestHandlers.Length, 0); - var estimate = baseSize + (long)perHandler * handlerCount; - estimate = (long)(estimate * 1.10); - - if (estimate > int.MaxValue) - estimate = int.MaxValue; - - return StringBuilderPool.Get((int)Math.Max(baseSize, estimate)); - } - - [SuppressMessage("Globalization", "CA1308: Normalize strings to uppercase", Justification = "C# boolean literals must be lowercase.")] - private static string ConstLiteral(TypedConstant tc) - { - if (tc.IsNull) - return "null"; - var v = tc.Value; - var t = tc.Type; - if (t is null) - return "null"; - - if (t.TypeKind != TypeKind.Enum) - return t.SpecialType switch - { - SpecialType.System_String => StringLiteral((string?)v), - SpecialType.System_Char => $"'{EscapeChar((char)v!)}'", - SpecialType.System_Boolean => ((bool)v!).ToString() - .ToLowerInvariant(), - SpecialType.System_Double => ((double)v!).ToString("R", CultureInfo.InvariantCulture), - SpecialType.System_Single => ((float)v!).ToString("R", CultureInfo.InvariantCulture) + "f", - SpecialType.System_Decimal => ((decimal)v!).ToString(CultureInfo.InvariantCulture) + "m", - SpecialType.System_SByte => ((sbyte)v!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_Byte => ((byte)v!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_Int16 => ((short)v!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_UInt16 => ((ushort)v!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_Int32 => ((int)v!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_UInt32 => ((uint)v!).ToString(CultureInfo.InvariantCulture) + "u", - SpecialType.System_Int64 => ((long)v!).ToString(CultureInfo.InvariantCulture) + "L", - SpecialType.System_UInt64 => ((ulong)v!).ToString(CultureInfo.InvariantCulture) + "UL", - _ => StringLiteral(v?.ToString()), - }; - - var field = t.GetMembers() - .OfType() - .FirstOrDefault(f => f.HasConstantValue && Equals(f.ConstantValue, v)); - - if (field is not null) - return $"{t.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}.{field.Name}"; - - var underlying = ((INamedTypeSymbol)t).EnumUnderlyingType!; - var num = IntegralLiteral(v, underlying.SpecialType); - return $"({t.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}){num}"; - } - - private static string IntegralLiteral(object? value, SpecialType underlying) - { - return underlying switch - { - SpecialType.System_SByte => ((sbyte)value!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_Byte => ((byte)value!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_Int16 => ((short)value!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_UInt16 => ((ushort)value!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_Int32 => ((int)value!).ToString(CultureInfo.InvariantCulture), - SpecialType.System_UInt32 => ((uint)value!).ToString(CultureInfo.InvariantCulture) + "u", - SpecialType.System_Int64 => ((long)value!).ToString(CultureInfo.InvariantCulture) + "L", - SpecialType.System_UInt64 => ((ulong)value!).ToString(CultureInfo.InvariantCulture) + "UL", - _ => "0", - }; - } - - private static string StringLiteral(string? value) - { - if (value is null) - return "null"; - - var firstEscapeIndex = -1; - for (var i = 0; i < value.Length; i++) - { - var c = value[i]; - if (c == '\"' || c == '\\' || c == '\n' || c == '\r' || c == '\t' || c == '\0' || char.IsControl(c)) - { - firstEscapeIndex = i; - break; - } - } - - if (firstEscapeIndex < 0) - return string.Concat("\"", value, "\""); - - var sb = StringBuilderPool.Get(value.Length + 2); - sb.Append('"'); - if (firstEscapeIndex > 0) - sb.Append(value, 0, firstEscapeIndex); - - for (var i = firstEscapeIndex; i < value.Length; i++) - { - var c = value[i]; - switch (c) - { - case '\"': - sb.Append("\\\""); - break; - case '\\': - sb.Append("\\\\"); - break; - case '\n': - sb.Append("\\n"); - break; - case '\r': - sb.Append("\\r"); - break; - case '\t': - sb.Append("\\t"); - break; - case '\0': - sb.Append("\\0"); - break; - default: - if (char.IsControl(c)) - sb.Append("\\u") - .Append(((int)c).ToString("x4", CultureInfo.InvariantCulture)); - else - sb.Append(c); - - break; - } - } - - sb.Append('"'); - return StringBuilderPool.ToStringAndReturn(sb); - } - - private static void AppendAdditionalContentTypes(StringBuilder source, EquatableImmutableArray? additionalContentTypes) - { - if (additionalContentTypes is not { Count: > 0 }) - return; - - foreach (var additional in additionalContentTypes.Value) - { - source.Append(", "); - source.Append(StringLiteral(additional)); - } - } - - private static void AppendCommaSeparatedLiterals(StringBuilder source, EquatableImmutableArray values) - { - if (values.Count == 0) - return; - - source.Append(StringLiteral(values[0])); - for (var i = 1; i < values.Count; i++) - { - source.Append(", "); - source.Append(StringLiteral(values[i])); - } - } - - private static void AppendOptionalContentTypes(StringBuilder source, string? contentType, EquatableImmutableArray? additionalContentTypes) - { - if (string.IsNullOrEmpty(contentType) && additionalContentTypes is not { Count: > 0 }) - return; - - source.Append(", "); - source.Append(contentType is { Length: > 0 } ? StringLiteral(contentType) : "null"); - AppendAdditionalContentTypes(source, additionalContentTypes); - } - - private static string EscapeChar(char c) - { - return c switch - { - '\'' => "\\'", - '\\' => "\\\\", - '\n' => "\\n", - '\r' => "\\r", - '\t' => "\\t", - '\0' => "\\0", - _ when char.IsControl(c) => "\\u" + ((int)c).ToString("x4", CultureInfo.InvariantCulture), - _ => c.ToString(), - }; } } diff --git a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs new file mode 100644 index 0000000..93ba94e --- /dev/null +++ b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs @@ -0,0 +1,667 @@ +using System.Collections.Immutable; +using System.Text; +using GeneratedEndpoints.Common; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.Text; +using static GeneratedEndpoints.Common.Constants; + +namespace GeneratedEndpoints; + +// ReSharper disable ForCanBeConvertedToForeach +// ReSharper disable LoopCanBeConvertedToQuery +// Do not refactor, use for loop to avoid allocations. + +internal static class UseEndpointHandlersGenerator +{ + public static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) + { + context.CancellationToken.ThrowIfCancellationRequested(); + + var source = GetUseEndpointHandlersStringBuilder(requestHandlers); + source.AppendLine(FileHeader); + + source.AppendLine(); + + source.AppendLine("using Microsoft.AspNetCore.Builder;"); + source.AppendLine("using Microsoft.AspNetCore.Http;"); + source.AppendLine("using Microsoft.AspNetCore.Mvc;"); + source.AppendLine("using Microsoft.AspNetCore.Routing;"); + if (HasRateLimitedHandlers(requestHandlers)) + source.AppendLine("using Microsoft.AspNetCore.RateLimiting;"); + source.AppendLine("using Microsoft.Extensions.DependencyInjection;"); + source.AppendLine(); + + source.Append("namespace "); + source.Append(RoutingNamespace); + source.AppendLine(";"); + + source.AppendLine(); + + source.Append("internal static class "); + source.Append(UseEndpointHandlersClassName); + source.AppendLine(); + + source.AppendLine("{"); + + source.Append(" internal static IEndpointRouteBuilder "); + source.Append(UseEndpointHandlersMethodName); + source.AppendLine("(this IEndpointRouteBuilder builder)"); + + source.AppendLine(" {"); + + var groupedClasses = GetClassesWithMapGroups(requestHandlers); + + for (var index = 0; index < groupedClasses.Count; index++) + { + var groupedClass = groupedClasses[index]; + source.Append(" var "); + source.Append(groupedClass.Configuration.GroupIdentifier); + source.Append(" = builder.MapGroup("); + source.Append(groupedClass.Configuration.GroupPattern!.ToStringLiteral()); + source.Append(')'); + AppendEndpointConfiguration(source, " ", groupedClass.Configuration); + source.AppendLine(";"); + } + + if (groupedClasses.Count > 0) + source.AppendLine(); + + for (var index = 0; index < requestHandlers.Count; index++) + { + if (index > 0) + source.AppendLine(); + + var requestHandler = requestHandlers[index]; + GenerateMapRequestHandler(source, requestHandler); + } + + source.AppendLine(""" + + return builder; + } + } + """ + ); + + var sourceText = StringBuilderPool.ToStringAndReturn(source); + context.AddSource(UseEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); + } + + private static bool HasRateLimitedHandlers(EquatableImmutableArray requestHandlers) + { + for (var index = 0; index < requestHandlers.Count; index++) + { + var handler = requestHandlers[index]; + if (handler.Class.Configuration.RequireRateLimiting || handler.Method.Configuration.RequireRateLimiting) + return true; + } + + return false; + } + + private static List GetClassesWithMapGroups(EquatableImmutableArray requestHandlers) + { + var groupedClasses = new List(); + if (requestHandlers.Count == 0) + return groupedClasses; + + var seen = new HashSet(StringComparer.Ordinal); + for (var index = 0; index < requestHandlers.Count; index++) + { + var handler = requestHandlers[index]; + var handlerClass = handler.Class; + if (handlerClass.Configuration.GroupPattern is null) + continue; + + if (seen.Add(handlerClass.Name)) + groupedClasses.Add(handlerClass); + } + + return groupedClasses; + } + + private static void GenerateMapRequestHandler(StringBuilder source, RequestHandler requestHandler) + { + var wrapWithConfigure = requestHandler.Class.HasConfigureMethod; + var configureAcceptsServiceProvider = requestHandler.Class.ConfigureMethodAcceptsServiceProvider; + var indent = wrapWithConfigure ? " " : " "; + var continuationIndent = indent + " "; + var routeBuilderIdentifier = requestHandler.Class.Configuration.GroupIdentifier ?? "builder"; + + if (wrapWithConfigure) + { + source.Append(" "); + source.Append(requestHandler.Class.Name); + source.Append('.'); + source.Append(ConfigureMethodName); + source.AppendLine("("); + } + + var isFallback = string.Equals(requestHandler.HttpMethod, FallbackHttpMethod, StringComparison.Ordinal); + var mapMethodSuffix = isFallback ? null : GetMapMethodSuffix(requestHandler.HttpMethod); + + source.Append(indent); + if (isFallback) + { + source.Append(routeBuilderIdentifier); + source.Append(".MapFallback("); + if (!string.IsNullOrEmpty(requestHandler.Pattern)) + { + source.Append(requestHandler.Pattern.ToStringLiteral()); + source.Append(", "); + } + } + else + { + source.Append(routeBuilderIdentifier); + source.Append(".Map"); + source.Append(mapMethodSuffix ?? "Methods"); + source.Append('('); + source.Append(requestHandler.Pattern.ToStringLiteral()); + source.Append(", "); + if (mapMethodSuffix is null) + { + source.Append("new[] { \""); + source.Append(requestHandler.HttpMethod); + source.Append("\" }, "); + } + } + if (requestHandler.Method.IsStatic) + { + source.Append(requestHandler.Class.Name); + source.Append('.'); + source.Append(requestHandler.Method.Name); + } + else + { + source.Append("static ([FromServices] "); + source.Append(requestHandler.Class.Name); + source.Append(" handler"); + foreach (var parameter in requestHandler.Method.Parameters) + { + source.Append(", "); + source.Append(parameter.BindingPrefix); + source.Append(parameter.Type); + source.Append(' '); + source.Append(parameter.Name); + } + source.Append(") => handler."); + source.Append(requestHandler.Method.Name); + source.Append('('); + for (var index = 0; index < requestHandler.Method.Parameters.Count; index++) + { + if (index > 0) + source.Append(", "); + var parameter = requestHandler.Method.Parameters[index]; + source.Append(parameter.Name); + } + source.Append(')'); + } + source.Append(')'); + + var configuration = requestHandler.Method.Configuration; + if (requestHandler.Class.Configuration.GroupPattern is null) + configuration = MergeEndpointConfigurations(requestHandler.Class.Configuration, configuration); + + if (!string.IsNullOrEmpty(requestHandler.Name)) + { + source.AppendLine(); + source.Append(continuationIndent); + source.Append(".WithName("); + source.Append(requestHandler.Name.ToStringLiteral()); + source.Append(')'); + } + AppendEndpointConfiguration(source, continuationIndent, configuration); + + if (wrapWithConfigure && configureAcceptsServiceProvider) + { + source.AppendLine(","); + source.Append(indent); + source.Append("builder.ServiceProvider"); + } + + if (wrapWithConfigure) + { + source.AppendLine(); + source.Append(" );"); + source.AppendLine(); + } + else + { + source.AppendLine(";"); + } + } + + private static void AppendEndpointConfiguration(StringBuilder source, string indent, EndpointConfiguration configuration) + { + if (!string.IsNullOrEmpty(configuration.DisplayName)) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithDisplayName("); + source.Append(configuration.DisplayName.ToStringLiteral()); + source.Append(')'); + } + + if (!string.IsNullOrEmpty(configuration.Summary)) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithSummary("); + source.Append(configuration.Summary.ToStringLiteral()); + source.Append(')'); + } + + if (!string.IsNullOrEmpty(configuration.Description)) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithDescription("); + source.Append(configuration.Description.ToStringLiteral()); + source.Append(')'); + } + + if (!string.IsNullOrEmpty(configuration.GroupName)) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithGroupName("); + source.Append(configuration.GroupName.ToStringLiteral()); + source.Append(')'); + } + + if (configuration.Order is { } orderValue) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithOrder("); + source.Append(orderValue); + source.Append(')'); + } + + if (configuration.ExcludeFromDescription) + { + source.AppendLine(); + source.Append(indent); + source.Append(".ExcludeFromDescription()"); + } + + if (configuration.Tags is { Count: > 0 }) + { + source.AppendLine(); + source.Append(indent); + source.Append(".WithTags("); + AppendCommaSeparatedLiterals(source, configuration.Tags.Value); + source.Append(')'); + } + + if (configuration.Accepts is { Count: > 0 }) + foreach (var accepts in configuration.Accepts.Value) + { + source.AppendLine(); + source.Append(indent); + source.Append(".Accepts<"); + source.Append(accepts.RequestType); + source.Append('>'); + source.Append('('); + if (accepts.IsOptional) + source.Append("isOptional: true, "); + source.Append(accepts.ContentType.ToStringLiteral()); + AppendAdditionalContentTypes(source, accepts.AdditionalContentTypes); + source.Append(')'); + } + + if (configuration.Produces is { Count: > 0 }) + foreach (var produces in configuration.Produces.Value) + { + source.AppendLine(); + source.Append(indent); + source.Append(".Produces<"); + source.Append(produces.ResponseType); + source.Append('>'); + source.Append('('); + source.Append(produces.StatusCode); + AppendOptionalContentTypes(source, produces.ContentType, produces.AdditionalContentTypes); + source.Append(')'); + } + + if (configuration.ProducesProblem is { Count: > 0 }) + foreach (var producesProblem in configuration.ProducesProblem.Value) + { + source.AppendLine(); + source.Append(indent); + source.Append(".ProducesProblem("); + source.Append(producesProblem.StatusCode); + AppendOptionalContentTypes(source, producesProblem.ContentType, producesProblem.AdditionalContentTypes); + source.Append(')'); + } + + if (configuration.ProducesValidationProblem is { Count: > 0 }) + foreach (var producesValidationProblem in configuration.ProducesValidationProblem.Value) + { + source.AppendLine(); + source.Append(indent); + source.Append(".ProducesValidationProblem("); + source.Append(producesValidationProblem.StatusCode); + AppendOptionalContentTypes(source, producesValidationProblem.ContentType, producesValidationProblem.AdditionalContentTypes); + source.Append(')'); + } + + if (configuration.RequireAuthorization) + { + source.AppendLine(); + if (configuration.AuthorizationPolicies is { Count: > 0 }) + { + source.Append(indent); + source.Append(".RequireAuthorization("); + AppendCommaSeparatedLiterals(source, configuration.AuthorizationPolicies.Value); + source.Append(')'); + } + else + { + source.Append(indent); + source.Append(".RequireAuthorization()"); + } + } + + if (configuration.RequireCors) + { + source.AppendLine(); + source.Append(indent); + if (!string.IsNullOrEmpty(configuration.CorsPolicyName)) + { + source.Append(".RequireCors("); + source.Append(configuration.CorsPolicyName.ToStringLiteral()); + source.Append(')'); + } + else + { + source.Append(".RequireCors()"); + } + } + + if (configuration.RequiredHosts is { Count: > 0 }) + { + source.AppendLine(); + source.Append(indent); + source.Append(".RequireHost("); + AppendCommaSeparatedLiterals(source, configuration.RequiredHosts.Value); + source.Append(')'); + } + + if (configuration.RequireRateLimiting && !string.IsNullOrEmpty(configuration.RateLimitingPolicyName)) + { + source.AppendLine(); + source.Append(indent); + source.Append(".RequireRateLimiting("); + source.Append(configuration.RateLimitingPolicyName.ToStringLiteral()); + source.Append(')'); + } + + if (configuration.DisableAntiforgery) + { + source.AppendLine(); + source.Append(indent); + source.Append(".DisableAntiforgery()"); + } + + if (configuration.AllowAnonymous) + { + source.AppendLine(); + source.Append(indent); + source.Append(".AllowAnonymous()"); + } + + if (configuration.ShortCircuit) + { + source.AppendLine(); + source.Append(indent); + source.Append(".ShortCircuit()"); + } + + if (configuration.DisableValidation) + { + source.AppendLine(); + source.Append(indent); + source.Append(".DisableValidation()"); + source.AppendLine(); + } + + if (configuration.DisableRequestTimeout) + { + source.AppendLine(); + source.Append(indent); + source.Append(".DisableRequestTimeout()"); + } + else if (configuration.WithRequestTimeout) + { + source.AppendLine(); + source.Append(indent); + if (!string.IsNullOrEmpty(configuration.RequestTimeoutPolicyName)) + { + source.Append(".WithRequestTimeout("); + source.Append(configuration.RequestTimeoutPolicyName.ToStringLiteral()); + source.Append(')'); + } + else + { + source.Append(".WithRequestTimeout()"); + } + } + + if (configuration.EndpointFilterTypes is { Count: > 0 }) + foreach (var filterType in configuration.EndpointFilterTypes.Value) + { + source.AppendLine(); + source.Append(indent); + source.Append(".AddEndpointFilter<"); + source.Append(filterType); + source.Append(">()"); + } + } + + private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfiguration classConfiguration, EndpointConfiguration methodConfiguration) + { + var displayName = methodConfiguration.DisplayName ?? classConfiguration.DisplayName; + var summary = methodConfiguration.Summary ?? classConfiguration.Summary; + var description = methodConfiguration.Description ?? classConfiguration.Description; + + var tags = MergeDistinctStrings(methodConfiguration.Tags, classConfiguration.Tags); + var accepts = ConcatEquatable(methodConfiguration.Accepts, classConfiguration.Accepts); + var produces = ConcatEquatable(methodConfiguration.Produces, classConfiguration.Produces); + var producesProblem = ConcatEquatable(methodConfiguration.ProducesProblem, classConfiguration.ProducesProblem); + var producesValidationProblem = ConcatEquatable(methodConfiguration.ProducesValidationProblem, classConfiguration.ProducesValidationProblem); + + var excludeFromDescription = methodConfiguration.ExcludeFromDescription || classConfiguration.ExcludeFromDescription; + + var authorizationPolicies = MergeDistinctStrings(methodConfiguration.AuthorizationPolicies, classConfiguration.AuthorizationPolicies); + var requiredHosts = MergeDistinctStrings(methodConfiguration.RequiredHosts, classConfiguration.RequiredHosts); + var endpointFilterTypes = ConcatEquatable(methodConfiguration.EndpointFilterTypes, classConfiguration.EndpointFilterTypes); + + var requireAuthorization = methodConfiguration.RequireAuthorization || classConfiguration.RequireAuthorization; + var disableAntiforgery = methodConfiguration.DisableAntiforgery || classConfiguration.DisableAntiforgery; + var allowAnonymous = methodConfiguration.AllowAnonymous || classConfiguration.AllowAnonymous; + + var requireCors = methodConfiguration.RequireCors || classConfiguration.RequireCors; + var corsPolicyName = methodConfiguration.CorsPolicyName ?? classConfiguration.CorsPolicyName; + + var requireRateLimiting = methodConfiguration.RequireRateLimiting || classConfiguration.RequireRateLimiting; + var rateLimitingPolicyName = methodConfiguration.RateLimitingPolicyName ?? classConfiguration.RateLimitingPolicyName; + + var shortCircuit = methodConfiguration.ShortCircuit || classConfiguration.ShortCircuit; + var disableValidation = methodConfiguration.DisableValidation || classConfiguration.DisableValidation; + var disableRequestTimeout = methodConfiguration.DisableRequestTimeout || classConfiguration.DisableRequestTimeout; + var withRequestTimeout = methodConfiguration.WithRequestTimeout || classConfiguration.WithRequestTimeout; + + var groupIdentifier = methodConfiguration.GroupIdentifier ?? classConfiguration.GroupIdentifier; + var groupPattern = methodConfiguration.GroupPattern ?? classConfiguration.GroupPattern; + var groupName = methodConfiguration.GroupName ?? classConfiguration.GroupName; + + string? requestTimeoutPolicyName = null; + if (methodConfiguration.WithRequestTimeout) + requestTimeoutPolicyName = methodConfiguration.RequestTimeoutPolicyName; + else if (classConfiguration.WithRequestTimeout) + requestTimeoutPolicyName = classConfiguration.RequestTimeoutPolicyName; + + if (disableRequestTimeout) + { + withRequestTimeout = false; + requestTimeoutPolicyName = null; + } + + var order = methodConfiguration.Order ?? classConfiguration.Order; + + return new EndpointConfiguration + { + DisplayName = displayName, + Summary = summary, + Description = description, + Tags = tags, + Accepts = accepts, + Produces = produces, + ProducesProblem = producesProblem, + ProducesValidationProblem = producesValidationProblem, + ExcludeFromDescription = excludeFromDescription, + RequireAuthorization = requireAuthorization, + AuthorizationPolicies = authorizationPolicies, + DisableAntiforgery = disableAntiforgery, + AllowAnonymous = allowAnonymous, + RequireCors = requireCors, + CorsPolicyName = corsPolicyName, + RequiredHosts = requiredHosts, + RequireRateLimiting = requireRateLimiting, + RateLimitingPolicyName = rateLimitingPolicyName, + EndpointFilterTypes = endpointFilterTypes, + ShortCircuit = shortCircuit, + DisableValidation = disableValidation, + DisableRequestTimeout = disableRequestTimeout, + WithRequestTimeout = withRequestTimeout, + RequestTimeoutPolicyName = requestTimeoutPolicyName, + Order = order, + GroupIdentifier = groupIdentifier, + GroupPattern = groupPattern, + GroupName = groupName, + }; + } + + private static EquatableImmutableArray? MergeDistinctStrings(EquatableImmutableArray? first, EquatableImmutableArray? second) + { + if (first is not { Count: > 0 }) + return second; + if (second is not { Count: > 0 }) + return first; + + var merged = MergeUnion(first, second.Value); + return merged.Count > 0 ? merged : null; + } + + private static EquatableImmutableArray MergeUnion(EquatableImmutableArray? existing, EquatableImmutableArray values) + { + List? list = null; + HashSet? seen = null; + + if (existing is { Count: > 0 }) + { + var count = existing.Value.Count; + list = new List(count + 4); + list.AddRange(existing.Value); + seen = new HashSet(existing.Value, StringComparer.OrdinalIgnoreCase); + } + + foreach (var value in values) + { + var normalized = value.NormalizeOptionalString(); + if (normalized is not { Length: > 0 }) + continue; + + seen ??= new HashSet(StringComparer.OrdinalIgnoreCase); + if (!seen.Add(normalized)) + continue; + + list ??= []; + list.Add(normalized); + } + + return list?.ToEquatableImmutableArray() ?? EquatableImmutableArray.Empty; + } + + private static EquatableImmutableArray? ConcatEquatable(EquatableImmutableArray? first, EquatableImmutableArray? second) + { + if (first is not { Count: > 0 }) + return second; + if (second is not { Count: > 0 }) + return first; + + var builder = ImmutableArray.CreateBuilder(first.Value.Count + second.Value.Count); + builder.AddRange(first.Value); + builder.AddRange(second.Value); + return builder.ToEquatableImmutableArray(); + } + + private static string? GetMapMethodSuffix(string httpMethod) + { + return httpMethod switch + { + "GET" => "Get", + "POST" => "Post", + "PUT" => "Put", + "DELETE" => "Delete", + "PATCH" => "Patch", + _ => null, + }; + } + + private static StringBuilder GetUseEndpointHandlersStringBuilder(EquatableImmutableArray requestHandlers) + { + const int baseSize = 4096; + const int perHandler = 512; + + var handlerCount = Math.Max(requestHandlers.Count, 0); + var estimate = baseSize + (long)perHandler * handlerCount; + estimate = (long)(estimate * 1.10); + + if (estimate > int.MaxValue) + estimate = int.MaxValue; + + return StringBuilderPool.Get((int)Math.Max(baseSize, estimate)); + } + + private static void AppendAdditionalContentTypes(StringBuilder source, EquatableImmutableArray? additionalContentTypes) + { + if (additionalContentTypes is not { Count: > 0 }) + return; + + foreach (var additional in additionalContentTypes.Value) + { + source.Append(", "); + source.Append(additional.ToStringLiteral()); + } + } + + private static void AppendCommaSeparatedLiterals(StringBuilder source, EquatableImmutableArray values) + { + if (values.Count == 0) + return; + + source.Append(values[0] + .ToStringLiteral() + ); + for (var i = 1; i < values.Count; i++) + { + source.Append(", "); + source.Append(values[i] + .ToStringLiteral() + ); + } + } + + private static void AppendOptionalContentTypes(StringBuilder source, string? contentType, EquatableImmutableArray? additionalContentTypes) + { + if (string.IsNullOrEmpty(contentType) && additionalContentTypes is not { Count: > 0 }) + return; + + source.Append(", "); + source.Append(contentType is { Length: > 0 } ? contentType.ToStringLiteral() : "null"); + AppendAdditionalContentTypes(source, additionalContentTypes); + } +} diff --git a/tests/GeneratedEndpoints.Tests.Lab/GeneratedEndpoints.Tests.Lab.csproj b/tests/GeneratedEndpoints.Tests.Lab/GeneratedEndpoints.Tests.Lab.csproj index 7ab05ac..db2be7e 100644 --- a/tests/GeneratedEndpoints.Tests.Lab/GeneratedEndpoints.Tests.Lab.csproj +++ b/tests/GeneratedEndpoints.Tests.Lab/GeneratedEndpoints.Tests.Lab.csproj @@ -12,11 +12,11 @@ - + - - appsettings.json - + + appsettings.json + diff --git a/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs b/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs index 4452d06..31fd2cc 100644 --- a/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs +++ b/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.ComponentModel; +using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Generated.Attributes; @@ -31,11 +32,13 @@ internal sealed class GetUserEndpoint(IServiceProvider serviceProvider) [Description("Gets a user by ID when the ID is greater than zero.")] [Summary("Gets a user by ID.")] [MapGet("/users/{id:int}", Name = nameof(GetUser))] - public Results, NotFound, ValidationProblem, ProblemHttpResult> GetUser( - [FromQuery] int id, + public async ValueTask, NotFound, ValidationProblem, ProblemHttpResult>> GetUser( + [FromHeader(Name = "4")] int id, [FromKeyedServices(ServiceLifetime.Scoped)] IServiceCollection services ) { + await Task.Yield(); + if (id <= 0) { var errors = new Dictionary diff --git a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.cs b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.cs index 1831f55..9b27e5f 100644 --- a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.cs +++ b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.cs @@ -8,8 +8,8 @@ namespace GeneratedEndpoints.Tests; public class AttributeGenerationTests { private const string AttributeTestSource = "internal static class AttributeTestEndpoints { }"; - private static readonly GeneratorDriverRunResult GeneratorResult = - TestHelpers.RunGenerator(TestHelpers.GetSources(AttributeTestSource, withNamespace: true)); + + private static readonly GeneratorDriverRunResult GeneratorResult = TestHelpers.RunGenerator(TestHelpers.GetSources(AttributeTestSource, true)); public AttributeGenerationTests() { @@ -18,116 +18,174 @@ public AttributeGenerationTests() [Fact] public Task MapGetAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapGetAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapGetAttribute.gs.cs"); + } [Fact] public Task MapPostAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapPostAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapPostAttribute.gs.cs"); + } [Fact] public Task MapPutAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapPutAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapPutAttribute.gs.cs"); + } [Fact] public Task MapPatchAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapPatchAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapPatchAttribute.gs.cs"); + } [Fact] public Task MapDeleteAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapDeleteAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapDeleteAttribute.gs.cs"); + } [Fact] public Task MapOptionsAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapOptionsAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapOptionsAttribute.gs.cs"); + } [Fact] public Task MapHeadAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapHeadAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapHeadAttribute.gs.cs"); + } [Fact] public Task MapQueryAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapQueryAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapQueryAttribute.gs.cs"); + } [Fact] public Task MapTraceAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapTraceAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapTraceAttribute.gs.cs"); + } [Fact] public Task MapConnectAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapConnectAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapConnectAttribute.gs.cs"); + } [Fact] public Task MapFallbackAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapFallbackAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapFallbackAttribute.gs.cs"); + } [Fact] public Task RequireAuthorizationAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireAuthorizationAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireAuthorizationAttribute.gs.cs"); + } [Fact] public Task RequireCorsAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireCorsAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireCorsAttribute.gs.cs"); + } [Fact] public Task RequireRateLimitingAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireRateLimitingAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireRateLimitingAttribute.gs.cs"); + } [Fact] public Task RequireHostAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireHostAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequireHostAttribute.gs.cs"); + } [Fact] public Task DisableAntiforgeryAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.DisableAntiforgeryAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.DisableAntiforgeryAttribute.gs.cs"); + } [Fact] public Task ShortCircuitAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ShortCircuitAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ShortCircuitAttribute.gs.cs"); + } [Fact] public Task DisableRequestTimeoutAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.DisableRequestTimeoutAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.DisableRequestTimeoutAttribute.gs.cs"); + } [Fact] public Task DisableValidationAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.DisableValidationAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.DisableValidationAttribute.gs.cs"); + } [Fact] public Task RequestTimeoutAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequestTimeoutAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.RequestTimeoutAttribute.gs.cs"); + } [Fact] public Task OrderAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.OrderAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.OrderAttribute.gs.cs"); + } [Fact] public Task MapGroupAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapGroupAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.MapGroupAttribute.gs.cs"); + } [Fact] public Task SummaryAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.SummaryAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.SummaryAttribute.gs.cs"); + } [Fact] public Task AcceptsAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.AcceptsAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.AcceptsAttribute.gs.cs"); + } [Fact] public Task EndpointFilterAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.EndpointFilterAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.EndpointFilterAttribute.gs.cs"); + } [Fact] public Task ProducesResponseAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ProducesResponseAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ProducesResponseAttribute.gs.cs"); + } [Fact] public Task ProducesProblemAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ProducesProblemAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ProducesProblemAttribute.gs.cs"); + } [Fact] public Task ProducesValidationProblemAttribute() - => VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ProducesValidationProblemAttribute.gs.cs"); + { + return VerifyAttributeAsync("Microsoft.AspNetCore.Generated.Attributes.ProducesValidationProblemAttribute.gs.cs"); + } private static Task VerifyAttributeAsync(string fileName) - => GeneratorResult.VerifyAsync(fileName); + { + return GeneratorResult.VerifyAsync(fileName); + } } diff --git a/tests/GeneratedEndpoints.Tests/Common/ScenarioNamer.cs b/tests/GeneratedEndpoints.Tests/Common/ScenarioNamer.cs index db7d771..27900d2 100644 --- a/tests/GeneratedEndpoints.Tests/Common/ScenarioNamer.cs +++ b/tests/GeneratedEndpoints.Tests/Common/ScenarioNamer.cs @@ -26,15 +26,13 @@ public static string Create(string prefix, params (string Name, object? Value)[] private static string Sanitize(object? value) { if (value is null) - { return "None"; - } return value switch { bool b => b ? "On" : "Off", string s => s, - _ => value.ToString() ?? "Value" + _ => value.ToString() ?? "Value", }; } } diff --git a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs index ac721c9..4fef5bd 100644 --- a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs +++ b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs @@ -53,29 +53,22 @@ public static string BuildAuthorizationMatrixSource( bool excludeFromDescription, string? mapGroupPattern = null, bool classDisableValidation = false, - bool methodDisableValidation = false) + bool methodDisableValidation = false + ) { var builder = new StringBuilder(); if (classAllowAnonymous) - { builder.AppendLine("[AllowAnonymous]"); - } if (classRequireAuthorization) - { builder.AppendLine("[RequireAuthorization(\"ClassPolicy\")]"); - } if (classTags) - { builder.AppendLine("[Tags(\"Class\", \"Matrix\")]"); - } if (!string.IsNullOrWhiteSpace(classHost)) - { builder.AppendLine($"[RequireHost(\"{classHost}\")]"); - } if (classRequireCors) { @@ -84,96 +77,70 @@ public static string BuildAuthorizationMatrixSource( } if (!string.IsNullOrWhiteSpace(groupName) && mapGroupPattern is null) - { - mapGroupPattern = string.Empty; - } + mapGroupPattern = ""; if (mapGroupPattern is not null) { var mapGroupAttribute = new StringBuilder(); mapGroupAttribute.Append($"[MapGroup(\"{mapGroupPattern}\""); if (!string.IsNullOrWhiteSpace(groupName)) - { mapGroupAttribute.Append($", Name = \"{groupName}\""); - } mapGroupAttribute.Append(")]"); builder.AppendLine(mapGroupAttribute.ToString()); } if (applyShortCircuit) - { builder.AppendLine("[ShortCircuit]"); - } if (classDisableValidation) - { builder.AppendLine("[DisableValidation]"); - } if (applyRequestTimeout) { - var timeoutArgument = string.IsNullOrWhiteSpace(requestTimeoutPolicy) - ? string.Empty - : $"(\"{requestTimeoutPolicy}\")"; + var timeoutArgument = string.IsNullOrWhiteSpace(requestTimeoutPolicy) ? "" : $"(\"{requestTimeoutPolicy}\")"; builder.AppendLine($"[RequestTimeout{timeoutArgument}]"); } if (disableRequestTimeout) - { builder.AppendLine("[DisableRequestTimeout]"); - } if (orderValue != 0) - { builder.AppendLine($"[Order({orderValue})]"); - } if (excludeFromDescription) - { builder.AppendLine("[ExcludeFromDescription]"); - } builder.AppendLine("internal sealed class AuthorizationMatrixEndpoints"); builder.AppendLine("{"); builder.AppendLine(" [MapGet(\"/matrix/{id:int}\", Name = \"GetMatrix\")]"); if (methodAllowAnonymous) - { builder.AppendLine(" [AllowAnonymous]"); - } if (methodRequireAuthorization) - { builder.AppendLine(" [RequireAuthorization(\"MethodPolicy\")]"); - } if (methodTags) - { builder.AppendLine(" [Tags(\"Method\", \"Matrix\")]"); - } if (!string.IsNullOrWhiteSpace(methodHost)) - { builder.AppendLine($" [RequireHost(\"{methodHost}\", \"contoso.com\")]"); - } if (methodRequireCors) { - var methodCors = string.IsNullOrWhiteSpace(methodCorsPolicy) ? string.Empty : $"(\"{methodCorsPolicy}\")"; + var methodCors = string.IsNullOrWhiteSpace(methodCorsPolicy) ? "" : $"(\"{methodCorsPolicy}\")"; builder.AppendLine($" [RequireCors{methodCors}]"); } if (requireRateLimiting) { - var rateLimit = string.IsNullOrWhiteSpace(rateLimitingPolicy) ? string.Empty : $"(\"{rateLimitingPolicy}\")"; + var rateLimit = string.IsNullOrWhiteSpace(rateLimitingPolicy) ? "" : $"(\"{rateLimitingPolicy}\")"; builder.AppendLine($" [RequireRateLimiting{rateLimit}]"); } if (methodDisableValidation) - { builder.AppendLine(" [DisableValidation]"); - } builder.AppendLine(" public static Ok Handle(int id) => id >= 0 ? TypedResults.Ok() : TypedResults.Ok();"); @@ -196,52 +163,44 @@ public static string BuildConfigureAndFiltersSource( bool includeMethodLevelFilter, bool includeGenericFilter, bool configureRegistersFilter, - string metadataValue) + string metadataValue + ) { var builder = new StringBuilder(); builder.AppendLine("using Microsoft.AspNetCore.Builder;"); builder.AppendLine(); if (includeClassLevelFilter) - { builder.AppendLine("[EndpointFilter(typeof(TimingFilter))]"); - } builder.AppendLine("internal static class ConfigureFilterEndpoints"); builder.AppendLine("{"); builder.AppendLine(" [MapGet(\"/configure-filters\")]"); if (includeMethodLevelFilter) - { builder.AppendLine(" [EndpointFilter(typeof(ValidationFilter))]"); - } if (includeGenericFilter) - { builder.AppendLine(" [EndpointFilter]"); - } builder.AppendLine(" public static Ok Handle() => TypedResults.Ok();"); builder.AppendLine(); - builder.AppendLine(" public static void Configure(TBuilder builder" + (configureWithServiceProvider ? ", IServiceProvider services" : string.Empty) + ")"); + builder.AppendLine(" public static void Configure(TBuilder builder" + + (configureWithServiceProvider ? ", IServiceProvider services" : "") + + ")" + ); builder.AppendLine(" where TBuilder : IEndpointConventionBuilder"); builder.AppendLine(" {"); builder.AppendLine(" _ = builder;"); if (configureWithServiceProvider) - { builder.AppendLine(" _ = services;"); - } if (configureAddsMetadata) - { builder.AppendLine($" builder.WithMetadata(\"{metadataValue}\");"); - } if (configureRegistersFilter) - { builder.AppendLine(" builder.AddEndpointFilterFactory((context, next) => next);"); - } builder.AppendLine(" }"); builder.AppendLine("}"); @@ -270,7 +229,8 @@ public static string BuildHttpMethodMatrixSource( bool includeQuery, bool includeTrace, bool includeConnect, - bool includeMethodNameCollision) + bool includeMethodNameCollision + ) { var builder = new StringBuilder(); builder.AppendLine("using Microsoft.AspNetCore.Mvc;"); @@ -279,54 +239,36 @@ public static string BuildHttpMethodMatrixSource( builder.AppendLine("{"); if (includeGet) - { builder.AppendLine(" [MapGet(\"/matrix\")] public static Ok Get() => TypedResults.Ok();"); - } if (includePost) - { builder.AppendLine(" [MapPost(\"/matrix\")] public static Created Post() => TypedResults.Created(\"/matrix/1\", \"Created\");"); - } if (includePut) - { - builder.AppendLine(" [MapPut(\"/matrix/{id:int}\")] public static Results Put(int id) => id > 0 ? TypedResults.NoContent() : TypedResults.NotFound();"); - } + builder.AppendLine( + " [MapPut(\"/matrix/{id:int}\")] public static Results Put(int id) => id > 0 ? TypedResults.NoContent() : TypedResults.NotFound();" + ); if (includeDelete) - { builder.AppendLine(" [MapDelete(\"/matrix/{id:int}\")] public static IResult Delete(int id) => TypedResults.Ok();"); - } if (includeOptions) - { builder.AppendLine(" [MapOptions(\"/matrix\")] public static IResult Options() => TypedResults.Ok();"); - } if (includeHead) - { builder.AppendLine(" [MapHead(\"/matrix\")] public static IResult Head() => TypedResults.Ok();"); - } if (includePatch) - { builder.AppendLine(" [MapPatch(\"/matrix/{id:int}\")] public static IResult Patch(int id) => TypedResults.Ok();"); - } if (includeQuery) - { builder.AppendLine(" [MapQuery(\"/matrix/query\")] public static IResult Query([FromQuery] string value) => TypedResults.Ok(value);"); - } if (includeTrace) - { builder.AppendLine(" [MapTrace(\"/matrix\")] public static IResult Trace() => TypedResults.Ok();"); - } if (includeConnect) - { builder.AppendLine(" [MapConnect(\"/matrix\")] public static IResult Connect() => TypedResults.Ok();"); - } builder.AppendLine("}"); @@ -343,6 +285,35 @@ public static string BuildHttpMethodMatrixSource( return builder.ToString(); } + public static string BuildEndpointNameCollisionSource() + { + return """ + internal static class AlphaEndpoints + { + [MapGet("/alpha/collision")] public static Ok Collision() => TypedResults.Ok("alpha-collision"); + [MapGet("/alpha/unique")] public static Ok UniqueAlpha() => TypedResults.Ok("unique-alpha"); + } + + internal static class BetaEndpoints + { + [MapGet("/beta/unique")] public static Ok UniqueBeta() => TypedResults.Ok("unique-beta"); + [MapGet("/beta/collision")] public static Ok Collision() => TypedResults.Ok("beta-collision"); + } + + internal static class GammaEndpoints + { + [MapGet("/gamma/collision")] public static Ok Collision() => TypedResults.Ok("gamma-collision"); + [MapGet("/gamma/unique")] public static Ok UniqueGamma() => TypedResults.Ok("unique-gamma"); + } + + internal static class DeltaEndpoints + { + [MapGet("/delta/unique")] public static Ok UniqueDelta() => TypedResults.Ok("unique-delta"); + [MapGet("/delta/collision")] public static Ok Collision() => TypedResults.Ok("delta-collision"); + } + """; + } + public static string BuildContractsAndBindingSource( bool includeBindingNames, bool includeAsParameters, @@ -362,7 +333,8 @@ public static string BuildContractsAndBindingSource( string? acceptsContentType1, string? acceptsContentType2, string? producesContentType1, - string? producesContentType2) + string? producesContentType2 + ) { var builder = new StringBuilder(); builder.AppendLine("using Microsoft.AspNetCore.Mvc;"); @@ -378,85 +350,59 @@ public static string BuildContractsAndBindingSource( } if (includeDisplayName) - { builder.AppendLine(" [DisplayName(\"Contract endpoint\")]"); - } if (includeTags) - { builder.AppendLine(" [Tags(\"Contracts\", \"Bindings\")]"); - } if (excludeFromDescription) - { builder.AppendLine(" [ExcludeFromDescription]"); - } if (allowAnonymous) - { builder.AppendLine(" [AllowAnonymous]"); - } if (methodRequiresAuthorization) - { builder.AppendLine(" [RequireAuthorization(\"ContractsPolicy\")]"); - } builder.AppendLine(" [MapGet(\"/contracts/{id:int}\")]"); if (includeAccepts) { - var secondContentType = string.IsNullOrWhiteSpace(acceptsContentType2) ? string.Empty : $", \"{acceptsContentType2}\""; + var secondContentType = string.IsNullOrWhiteSpace(acceptsContentType2) ? "" : $", \"{acceptsContentType2}\""; builder.AppendLine($" [Accepts(\"{acceptsContentType1 ?? "application/json"}\"{secondContentType})]"); } if (includeGenericAccepts) - { builder.AppendLine($" [Accepts(\"{acceptsContentType1 ?? "application/json"}\")]"); - } if (includeProducesResponse) { - var secondProduces = string.IsNullOrWhiteSpace(producesContentType2) ? string.Empty : $", \"{producesContentType2}\""; - builder.AppendLine($" [ProducesResponse(200, \"{producesContentType1 ?? "application/json"}\"{secondProduces}, ResponseType = typeof(ResponseRecord))]"); + var secondProduces = string.IsNullOrWhiteSpace(producesContentType2) ? "" : $", \"{producesContentType2}\""; + builder.AppendLine( + $" [ProducesResponse(200, \"{producesContentType1 ?? "application/json"}\"{secondProduces}, ResponseType = typeof(ResponseRecord))]" + ); } if (includeProducesProblem) - { builder.AppendLine($" [ProducesProblem(500, \"{producesContentType1 ?? "application/problem+json"}\")]"); - } if (includeProducesValidationProblem) - { builder.AppendLine($" [ProducesValidationProblem(422, \"{producesContentType1 ?? "application/problem+json"}\")]"); - } builder.AppendLine(" public static async Task, NotFound>> Handle("); - builder.AppendLine(includeBindingNames - ? " [FromRoute(Name = \"route-id\")] int id," - : " [FromRoute] int id,"); - builder.AppendLine(includeBindingNames - ? " [FromQuery(Name = \"filter-term\")] string? filter," - : " [FromQuery] string? filter,"); - builder.AppendLine(includeBindingNames - ? " [FromHeader(Name = \"x-trace-id\")] string? traceId," - : " [FromHeader] string? traceId,"); + builder.AppendLine(includeBindingNames ? " [FromRoute(Name = \"route-id\")] int id," : " [FromRoute] int id,"); + builder.AppendLine(includeBindingNames ? " [FromQuery(Name = \"filter-term\")] string? filter," : " [FromQuery] string? filter,"); + builder.AppendLine(includeBindingNames ? " [FromHeader(Name = \"x-trace-id\")] string? traceId," : " [FromHeader] string? traceId,"); builder.AppendLine(" [FromBody] RequestRecord request,"); if (includeAsParameters) - { builder.AppendLine(" [AsParameters] AdditionalParameters parameters,"); - } if (includeFromServices) - { builder.AppendLine(" [FromServices] IServiceProvider services,"); - } if (includeFromKeyedServices) - { builder.AppendLine(" [FromKeyedServices(\"special\")] object keyed,"); - } builder.AppendLine(" CancellationToken cancellationToken)"); builder.AppendLine(" {"); @@ -470,9 +416,7 @@ public static string BuildContractsAndBindingSource( builder.AppendLine("internal sealed record ResponseRecord(int Value);"); if (includeAsParameters) - { builder.AppendLine("internal sealed record AdditionalParameters(string? Search, int? Page);"); - } return builder.ToString(); } diff --git a/tests/GeneratedEndpoints.Tests/Common/TestHelpers.cs b/tests/GeneratedEndpoints.Tests/Common/TestHelpers.cs index 8771b24..e4a4dfb 100644 --- a/tests/GeneratedEndpoints.Tests/Common/TestHelpers.cs +++ b/tests/GeneratedEndpoints.Tests/Common/TestHelpers.cs @@ -11,7 +11,8 @@ public static GeneratorDriverRunResult RunGenerator(IEnumerable sources) { var cSharpParseOptions = new CSharpParseOptions(LanguageVersion.CSharp11).WithPreprocessorSymbols("NET10_0_OR_GREATER"); var cSharpCompilationOptions = new CSharpCompilationOptions(OutputKind.NetModule).WithNullableContextOptions(NullableContextOptions.Enable); - var (_, result) = IncrementalGenerator.RunWithDiagnostics(sources, cSharpParseOptions, AspNet100.References.All, cSharpCompilationOptions); + var (_, result) = + IncrementalGenerator.RunWithDiagnostics(sources, cSharpParseOptions, AspNet100.References.All, cSharpCompilationOptions); return result; } diff --git a/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj b/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj index 1f74d48..1bd4374 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj +++ b/tests/GeneratedEndpoints.Tests/GeneratedEndpoints.Tests.csproj @@ -38,4 +38,22 @@ + + + GeneratedSourceTests.cs + + + IndividualTests.cs + + + GeneratedSourceTests.cs + + + GeneratedSourceTests.cs + + + GeneratedSourceTests.cs + + + diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ConfigureAndFiltersMatrix_25B08C7DE832_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ConfigureAndFiltersMatrix_25B08C7DE832_MapEndpointHandlers.verified.txt index d1abab8..d95bea0 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ConfigureAndFiltersMatrix_25B08C7DE832_MapEndpointHandlers.verified.txt +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ConfigureAndFiltersMatrix_25B08C7DE832_MapEndpointHandlers.verified.txt @@ -24,8 +24,8 @@ internal static class EndpointRouteBuilderExtensions { builder.MapGet("/configure-filters", global::GeneratedEndpointsTests.ConfigureFilterEndpoints.Handle) .WithName("Handle") - .AddEndpointFilter() - .AddEndpointFilter(); + .AddEndpointFilter() + .AddEndpointFilter(); return builder; } diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ConfigureAndFiltersMatrix_401A05F2C177_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ConfigureAndFiltersMatrix_401A05F2C177_MapEndpointHandlers.verified.txt index d1abab8..d95bea0 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ConfigureAndFiltersMatrix_401A05F2C177_MapEndpointHandlers.verified.txt +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ConfigureAndFiltersMatrix_401A05F2C177_MapEndpointHandlers.verified.txt @@ -24,8 +24,8 @@ internal static class EndpointRouteBuilderExtensions { builder.MapGet("/configure-filters", global::GeneratedEndpointsTests.ConfigureFilterEndpoints.Handle) .WithName("Handle") - .AddEndpointFilter() - .AddEndpointFilter(); + .AddEndpointFilter() + .AddEndpointFilter(); return builder; } diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs index 430d05b..dd38cbb 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs @@ -21,11 +21,9 @@ public async Task MapFallbackScenarios(bool withNamespace, bool includeDefaultFa { var sources = TestHelpers.GetSources(SourceFactory.BuildFallbackSource(includeDefaultFallback, includeCustomFallback, customRoute), withNamespace); var result = TestHelpers.RunGenerator(sources); - var scenario = ScenarioNamer.Create(nameof(MapFallbackScenarios), - ("Namespace", withNamespace), - ("Default", includeDefaultFallback), - ("Custom", includeCustomFallback), - ("Route", customRoute ?? "default")); + var scenario = ScenarioNamer.Create(nameof(MapFallbackScenarios), ("Namespace", withNamespace), ("Default", includeDefaultFallback), + ("Custom", includeCustomFallback), ("Route", customRoute ?? "default") + ); await result.VerifyAsync("AddEndpointHandlers.g.cs") .UseMethodName($"{scenario}_AddEndpointHandlers"); @@ -35,11 +33,21 @@ await result.VerifyAsync("MapEndpointHandlers.g.cs") } [Theory] - [InlineData(true, true, false, true, false, true, true, "*.contoso.com", "api.contoso.com", true, "NamedCorsPolicy", false, null, true, "RatePolicy", true, true, "TimeoutPolicy", false, 5, "Reporting", true)] - [InlineData(false, false, true, false, true, false, true, null, "services.contoso.com", false, null, true, "MethodCors", true, null, false, false, null, true, -1, null, false)] - [InlineData(true, true, true, true, true, true, false, "*.example.com", null, true, null, true, null, false, null, false, true, null, true, 0, "Operations", true)] - [InlineData(false, false, false, true, false, true, false, null, "*.alt.com", false, "CorsDefault", false, null, false, null, true, false, null, false, 10, null, false)] - [InlineData(true, false, true, false, true, false, true, "api.alt.com", null, true, null, true, "MethodCors", true, "BurstPolicy", false, true, "TimeoutPolicy", true, -5, "Docs", true)] + [InlineData(true, true, false, true, false, true, true, "*.contoso.com", "api.contoso.com", true, "NamedCorsPolicy", false, null, true, "RatePolicy", true, + true, "TimeoutPolicy", false, 5, "Reporting", true + )] + [InlineData(false, false, true, false, true, false, true, null, "services.contoso.com", false, null, true, "MethodCors", true, null, false, false, null, + true, -1, null, false + )] + [InlineData(true, true, true, true, true, true, false, "*.example.com", null, true, null, true, null, false, null, false, true, null, true, 0, "Operations", + true + )] + [InlineData(false, false, false, true, false, true, false, null, "*.alt.com", false, "CorsDefault", false, null, false, null, true, false, null, false, 10, + null, false + )] + [InlineData(true, false, true, false, true, false, true, "api.alt.com", null, true, null, true, "MethodCors", true, "BurstPolicy", false, true, + "TimeoutPolicy", true, -5, "Docs", true + )] public async Task AuthorizationAndMetadataMatrix( bool withNamespace, bool classAllowAnonymous, @@ -62,52 +70,24 @@ public async Task AuthorizationAndMetadataMatrix( bool disableRequestTimeout, int orderValue, string? groupName, - bool excludeFromDescription) + bool excludeFromDescription + ) { - var source = SourceFactory.BuildAuthorizationMatrixSource( - classAllowAnonymous, - methodAllowAnonymous, - classRequireAuthorization, - methodRequireAuthorization, - classTags, - methodTags, - classHost, - methodHost, - classRequireCors, - classCorsPolicy, - methodRequireCors, - methodCorsPolicy, - requireRateLimiting, - rateLimitingPolicy, - applyShortCircuit, - applyRequestTimeout, - requestTimeoutPolicy, - disableRequestTimeout, - orderValue, - groupName, - excludeFromDescription); + var source = SourceFactory.BuildAuthorizationMatrixSource(classAllowAnonymous, methodAllowAnonymous, classRequireAuthorization, + methodRequireAuthorization, classTags, methodTags, classHost, methodHost, classRequireCors, classCorsPolicy, methodRequireCors, methodCorsPolicy, + requireRateLimiting, rateLimitingPolicy, applyShortCircuit, applyRequestTimeout, requestTimeoutPolicy, disableRequestTimeout, orderValue, groupName, + excludeFromDescription + ); var sources = TestHelpers.GetSources(source, withNamespace); var result = TestHelpers.RunGenerator(sources); - var scenario = ScenarioNamer.Create(nameof(AuthorizationAndMetadataMatrix), - ("Namespace", withNamespace), - ("ClassAnon", classAllowAnonymous), - ("MethodAnon", methodAllowAnonymous), - ("ClassAuth", classRequireAuthorization), - ("MethodAuth", methodRequireAuthorization), - ("ClassTags", classTags), - ("MethodTags", methodTags), - ("ClassHost", classHost ?? "none"), - ("MethodHost", methodHost ?? "none"), - ("ClassCors", classRequireCors), - ("MethodCors", methodRequireCors), - ("RateLimit", requireRateLimiting), - ("ShortCircuit", applyShortCircuit), - ("RequestTimeout", applyRequestTimeout), - ("DisableTimeout", disableRequestTimeout), - ("Order", orderValue), - ("Group", groupName ?? "none"), - ("Exclude", excludeFromDescription)); + var scenario = ScenarioNamer.Create(nameof(AuthorizationAndMetadataMatrix), ("Namespace", withNamespace), ("ClassAnon", classAllowAnonymous), + ("MethodAnon", methodAllowAnonymous), ("ClassAuth", classRequireAuthorization), ("MethodAuth", methodRequireAuthorization), + ("ClassTags", classTags), ("MethodTags", methodTags), ("ClassHost", classHost ?? "none"), ("MethodHost", methodHost ?? "none"), + ("ClassCors", classRequireCors), ("MethodCors", methodRequireCors), ("RateLimit", requireRateLimiting), ("ShortCircuit", applyShortCircuit), + ("RequestTimeout", applyRequestTimeout), ("DisableTimeout", disableRequestTimeout), ("Order", orderValue), ("Group", groupName ?? "none"), + ("Exclude", excludeFromDescription) + ); await result.VerifyAsync("AddEndpointHandlers.g.cs") .UseMethodName($"{scenario}_AddEndpointHandlers"); @@ -130,28 +110,19 @@ public async Task ConfigureAndFiltersMatrix( bool includeMethodLevelFilter, bool includeGenericFilter, bool configureRegistersFilter, - string metadataValue) + string metadataValue + ) { - var source = SourceFactory.BuildConfigureAndFiltersSource( - configureWithServiceProvider, - configureAddsMetadata, - includeClassLevelFilter, - includeMethodLevelFilter, - includeGenericFilter, - configureRegistersFilter, - metadataValue); + var source = SourceFactory.BuildConfigureAndFiltersSource(configureWithServiceProvider, configureAddsMetadata, includeClassLevelFilter, + includeMethodLevelFilter, includeGenericFilter, configureRegistersFilter, metadataValue + ); var sources = TestHelpers.GetSources(source, withNamespace); var result = TestHelpers.RunGenerator(sources); - var scenario = ScenarioNamer.Create(nameof(ConfigureAndFiltersMatrix), - ("Namespace", withNamespace), - ("SvcProvider", configureWithServiceProvider), - ("Metadata", configureAddsMetadata), - ("ClassFilter", includeClassLevelFilter), - ("MethodFilter", includeMethodLevelFilter), - ("GenericFilter", includeGenericFilter), - ("ConfigureFilter", configureRegistersFilter), - ("Value", metadataValue)); + var scenario = ScenarioNamer.Create(nameof(ConfigureAndFiltersMatrix), ("Namespace", withNamespace), ("SvcProvider", configureWithServiceProvider), + ("Metadata", configureAddsMetadata), ("ClassFilter", includeClassLevelFilter), ("MethodFilter", includeMethodLevelFilter), + ("GenericFilter", includeGenericFilter), ("ConfigureFilter", configureRegistersFilter), ("Value", metadataValue) + ); await result.VerifyAsync("AddEndpointHandlers.g.cs") .UseMethodName($"{scenario}_AddEndpointHandlers"); @@ -178,36 +149,19 @@ public async Task HttpMethodMatrix( bool includeQuery, bool includeTrace, bool includeConnect, - bool includeMethodNameCollision) + bool includeMethodNameCollision + ) { - var source = SourceFactory.BuildHttpMethodMatrixSource( - includeGet, - includePost, - includePut, - includeDelete, - includeOptions, - includeHead, - includePatch, - includeQuery, - includeTrace, - includeConnect, - includeMethodNameCollision); + var source = SourceFactory.BuildHttpMethodMatrixSource(includeGet, includePost, includePut, includeDelete, includeOptions, includeHead, includePatch, + includeQuery, includeTrace, includeConnect, includeMethodNameCollision + ); var sources = TestHelpers.GetSources(source, withNamespace); var result = TestHelpers.RunGenerator(sources); - var scenario = ScenarioNamer.Create(nameof(HttpMethodMatrix), - ("Namespace", withNamespace), - ("Get", includeGet), - ("Post", includePost), - ("Put", includePut), - ("Delete", includeDelete), - ("Options", includeOptions), - ("Head", includeHead), - ("Patch", includePatch), - ("Query", includeQuery), - ("Trace", includeTrace), - ("Connect", includeConnect), - ("Collision", includeMethodNameCollision)); + var scenario = ScenarioNamer.Create(nameof(HttpMethodMatrix), ("Namespace", withNamespace), ("Get", includeGet), ("Post", includePost), + ("Put", includePut), ("Delete", includeDelete), ("Options", includeOptions), ("Head", includeHead), ("Patch", includePatch), + ("Query", includeQuery), ("Trace", includeTrace), ("Connect", includeConnect), ("Collision", includeMethodNameCollision) + ); await result.VerifyAsync("AddEndpointHandlers.g.cs") .UseMethodName($"{scenario}_AddEndpointHandlers"); @@ -217,10 +171,16 @@ await result.VerifyAsync("MapEndpointHandlers.g.cs") } [Theory] - [InlineData(true, true, true, true, true, true, true, true, true, true, true, true, true, false, true, false, "application/xml", "text/xml", "application/json", "text/json")] - [InlineData(false, false, true, false, false, true, false, true, true, false, false, false, true, true, false, true, "application/custom", null, "application/problem+json", null)] + [InlineData(true, true, true, true, true, true, true, true, true, true, true, true, true, false, true, false, "application/xml", "text/xml", + "application/json", "text/json" + )] + [InlineData(false, false, true, false, false, true, false, true, true, false, false, false, true, true, false, true, "application/custom", null, + "application/problem+json", null + )] [InlineData(true, true, false, true, true, false, true, false, false, true, true, true, false, false, true, true, null, null, null, null)] - [InlineData(false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, "application/xml", null, "application/json", null)] + [InlineData(false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, "application/xml", null, + "application/json", null + )] [InlineData(true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, null, "text/plain", null, "text/plain")] public async Task ContractsAndBindingMatrix( bool withNamespace, @@ -242,48 +202,23 @@ public async Task ContractsAndBindingMatrix( string? acceptsContentType1, string? acceptsContentType2, string? producesContentType1, - string? producesContentType2) + string? producesContentType2 + ) { - var source = SourceFactory.BuildContractsAndBindingSource( - includeBindingNames, - includeAsParameters, - includeFromServices, - includeFromKeyedServices, - includeAccepts, - includeGenericAccepts, - includeProducesResponse, - includeProducesProblem, - includeProducesValidationProblem, - includeSummaryAndDescription, - includeDisplayName, - includeTags, - excludeFromDescription, - allowAnonymous, - methodRequiresAuthorization, - acceptsContentType1, - acceptsContentType2, - producesContentType1, - producesContentType2); + var source = SourceFactory.BuildContractsAndBindingSource(includeBindingNames, includeAsParameters, includeFromServices, includeFromKeyedServices, + includeAccepts, includeGenericAccepts, includeProducesResponse, includeProducesProblem, includeProducesValidationProblem, + includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, methodRequiresAuthorization, + acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 + ); var sources = TestHelpers.GetSources(source, withNamespace); var result = TestHelpers.RunGenerator(sources); - var scenario = ScenarioNamer.Create(nameof(ContractsAndBindingMatrix), - ("Namespace", withNamespace), - ("BindingNames", includeBindingNames), - ("AsParameters", includeAsParameters), - ("Services", includeFromServices), - ("KeyedServices", includeFromKeyedServices), - ("Accepts", includeAccepts), - ("GenericAccepts", includeGenericAccepts), - ("Produces", includeProducesResponse), - ("Problem", includeProducesProblem), - ("Validation", includeProducesValidationProblem), - ("Summary", includeSummaryAndDescription), - ("DisplayName", includeDisplayName), - ("Tags", includeTags), - ("Exclude", excludeFromDescription), - ("AllowAnon", allowAnonymous), - ("MethodAuth", methodRequiresAuthorization)); + var scenario = ScenarioNamer.Create(nameof(ContractsAndBindingMatrix), ("Namespace", withNamespace), ("BindingNames", includeBindingNames), + ("AsParameters", includeAsParameters), ("Services", includeFromServices), ("KeyedServices", includeFromKeyedServices), ("Accepts", includeAccepts), + ("GenericAccepts", includeGenericAccepts), ("Produces", includeProducesResponse), ("Problem", includeProducesProblem), + ("Validation", includeProducesValidationProblem), ("Summary", includeSummaryAndDescription), ("DisplayName", includeDisplayName), + ("Tags", includeTags), ("Exclude", excludeFromDescription), ("AllowAnon", allowAnonymous), ("MethodAuth", methodRequiresAuthorization) + ); await result.VerifyAsync("AddEndpointHandlers.g.cs") .UseMethodName($"{scenario}_AddEndpointHandlers"); diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.AsyncMethodVariants_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/IndividualTests.AsyncMethodVariants_MapEndpointHandlers.verified.txt index 2d957e5..93fa041 100644 --- a/tests/GeneratedEndpoints.Tests/IndividualTests.AsyncMethodVariants_MapEndpointHandlers.verified.txt +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.AsyncMethodVariants_MapEndpointHandlers.verified.txt @@ -25,13 +25,13 @@ internal static class EndpointRouteBuilderExtensions builder.MapGet("/task", static ([FromServices] global::GeneratedEndpointsTests.AsyncHandlerEndpoints handler) => handler.TaskOnly()) .WithName("TaskOnly"); - builder.MapGet("/task-result", static async ([FromServices] global::GeneratedEndpointsTests.AsyncHandlerEndpoints handler, int id) => await handler.TaskWithResult(id)) + builder.MapGet("/task-result", static ([FromServices] global::GeneratedEndpointsTests.AsyncHandlerEndpoints handler, int id) => handler.TaskWithResult(id)) .WithName("TaskWithResult"); builder.MapPost("/valuetask", static ([FromServices] global::GeneratedEndpointsTests.AsyncHandlerEndpoints handler) => handler.ValueTaskOnly()) .WithName("ValueTaskOnly"); - builder.MapPost("/valuetask-result", static async ([FromServices] global::GeneratedEndpointsTests.AsyncHandlerEndpoints handler, int id) => await handler.ValueTaskWithResult(id)) + builder.MapPost("/valuetask-result", static ([FromServices] global::GeneratedEndpointsTests.AsyncHandlerEndpoints handler, int id) => handler.ValueTaskWithResult(id)) .WithName("ValueTaskWithResult"); return builder; diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_AddEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_AddEndpointHandlers.verified.txt new file mode 100644 index 0000000..64b4329 --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_AddEndpointHandlers.verified.txt @@ -0,0 +1,23 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointServicesExtensions +{ + internal static void AddEndpointHandlers(this IServiceCollection services) + { + } +} diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_MapEndpointHandlers.verified.txt new file mode 100644 index 0000000..8950dd4 --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.MultipleEndpointNameCollisions_MapEndpointHandlers.verified.txt @@ -0,0 +1,51 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointRouteBuilderExtensions +{ + internal static IEndpointRouteBuilder MapEndpointHandlers(this IEndpointRouteBuilder builder) + { + builder.MapGet("/alpha/collision", global::GeneratedEndpointsTests.AlphaEndpoints.Collision) + .WithName("GeneratedEndpointsTests.AlphaEndpoints.Collision"); + + builder.MapGet("/alpha/unique", global::GeneratedEndpointsTests.AlphaEndpoints.UniqueAlpha) + .WithName("UniqueAlpha"); + + builder.MapGet("/beta/collision", global::GeneratedEndpointsTests.BetaEndpoints.Collision) + .WithName("GeneratedEndpointsTests.BetaEndpoints.Collision"); + + builder.MapGet("/beta/unique", global::GeneratedEndpointsTests.BetaEndpoints.UniqueBeta) + .WithName("UniqueBeta"); + + builder.MapGet("/delta/collision", global::GeneratedEndpointsTests.DeltaEndpoints.Collision) + .WithName("GeneratedEndpointsTests.DeltaEndpoints.Collision"); + + builder.MapGet("/delta/unique", global::GeneratedEndpointsTests.DeltaEndpoints.UniqueDelta) + .WithName("UniqueDelta"); + + builder.MapGet("/gamma/collision", global::GeneratedEndpointsTests.GammaEndpoints.Collision) + .WithName("GeneratedEndpointsTests.GammaEndpoints.Collision"); + + builder.MapGet("/gamma/unique", global::GeneratedEndpointsTests.GammaEndpoints.UniqueGamma) + .WithName("UniqueGamma"); + + return builder; + } +} diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.cs b/tests/GeneratedEndpoints.Tests/IndividualTests.cs index e6c0265..be6ee72 100644 --- a/tests/GeneratedEndpoints.Tests/IndividualTests.cs +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.cs @@ -14,7 +14,7 @@ public IndividualTests() [Fact] public async Task DefaultFallbackOnly() { - var source = FallbackScenario(includeDefault: true); + var source = FallbackScenario(true); await VerifyIndividualAsync(source, nameof(DefaultFallbackOnly)); } @@ -28,7 +28,7 @@ public async Task CustomFallbackRoute() [Fact] public async Task ClassAllowAnonymous() { - var source = AuthorizationScenario(classAllowAnonymous: true); + var source = AuthorizationScenario(true); await VerifyIndividualAsync(source, nameof(ClassAllowAnonymous)); } @@ -173,32 +173,20 @@ public async Task OrderMetadata() } [Fact] - public async Task GroupName() + public async Task ClassMapGroup() { - var source = AuthorizationScenario(groupName: "IndividualGroup"); - await VerifyIndividualAsync(source, nameof(GroupName)); + var source = AuthorizationScenario(classRequireAuthorization: true, classTags: true, classHost: "*.individual.com", classRequireCors: true, + classCorsPolicy: "ClassCors", applyShortCircuit: true, applyRequestTimeout: true, requestTimeoutPolicy: "ClassTimeout", orderValue: 2, + groupName: "ClassGroup", excludeFromDescription: true, methodAllowAnonymous: true, methodTags: true, mapGroupPattern: "/individuals" + ); + await VerifyIndividualAsync(source, nameof(ClassMapGroup)); } [Fact] - public async Task ClassMapGroup() + public async Task GroupName() { - var source = AuthorizationScenario( - classRequireAuthorization: true, - classTags: true, - classHost: "*.individual.com", - classRequireCors: true, - classCorsPolicy: "ClassCors", - applyShortCircuit: true, - applyRequestTimeout: true, - requestTimeoutPolicy: "ClassTimeout", - orderValue: 2, - groupName: "ClassGroup", - excludeFromDescription: true, - methodAllowAnonymous: true, - methodTags: true, - mapGroupPattern: "/individuals" - ); - await VerifyIndividualAsync(source, nameof(ClassMapGroup)); + var source = AuthorizationScenario(groupName: "IndividualGroup"); + await VerifyIndividualAsync(source, nameof(GroupName)); } [Fact] @@ -232,7 +220,7 @@ public async Task GenericEndpointFilter() [Fact] public async Task ConfigureWithServiceProvider() { - var source = ConfigureScenario(configureWithServiceProvider: true); + var source = ConfigureScenario(true); await VerifyIndividualAsync(source, nameof(ConfigureWithServiceProvider)); } @@ -253,7 +241,7 @@ public async Task ConfigureRegistersFilter() [Fact] public async Task MapGetEndpoint() { - var source = HttpMethodScenario(includeGet: true); + var source = HttpMethodScenario(true); await VerifyIndividualAsync(source, nameof(MapGetEndpoint)); } @@ -323,14 +311,21 @@ public async Task MapConnectEndpoint() [Fact] public async Task MethodNameCollision() { - var source = HttpMethodScenario(includeGet: true, includeMethodNameCollision: true); + var source = HttpMethodScenario(true, includeMethodNameCollision: true); await VerifyIndividualAsync(source, nameof(MethodNameCollision)); } + [Fact] + public async Task MultipleEndpointNameCollisions() + { + var source = EndpointNameCollisionScenario(); + await VerifyIndividualAsync(source, nameof(MultipleEndpointNameCollisions)); + } + [Fact] public async Task BindingNames() { - var source = ContractScenario(includeBindingNames: true); + var source = ContractScenario(true); await VerifyIndividualAsync(source, nameof(BindingNames)); } @@ -466,7 +461,9 @@ await result.VerifyAsync("MapEndpointHandlers.g.cs") } private static string FallbackScenario(bool includeDefault = false, bool includeCustom = false, string? customRoute = null) - => SourceFactory.BuildFallbackSource(includeDefault, includeCustom, customRoute); + { + return SourceFactory.BuildFallbackSource(includeDefault, includeCustom, customRoute); + } private static string AuthorizationScenario( bool classAllowAnonymous = false, @@ -492,32 +489,15 @@ private static string AuthorizationScenario( bool excludeFromDescription = false, string? mapGroupPattern = null, bool classDisableValidation = false, - bool methodDisableValidation = false) - => SourceFactory.BuildAuthorizationMatrixSource( - classAllowAnonymous, - methodAllowAnonymous, - classRequireAuthorization, - methodRequireAuthorization, - classTags, - methodTags, - classHost, - methodHost, - classRequireCors, - classCorsPolicy, - methodRequireCors, - methodCorsPolicy, - requireRateLimiting, - rateLimitingPolicy, - applyShortCircuit, - applyRequestTimeout, - requestTimeoutPolicy, - disableRequestTimeout, - orderValue, - groupName, - excludeFromDescription, - mapGroupPattern, - classDisableValidation, - methodDisableValidation); + bool methodDisableValidation = false + ) + { + return SourceFactory.BuildAuthorizationMatrixSource(classAllowAnonymous, methodAllowAnonymous, classRequireAuthorization, methodRequireAuthorization, + classTags, methodTags, classHost, methodHost, classRequireCors, classCorsPolicy, methodRequireCors, methodCorsPolicy, requireRateLimiting, + rateLimitingPolicy, applyShortCircuit, applyRequestTimeout, requestTimeoutPolicy, disableRequestTimeout, orderValue, groupName, + excludeFromDescription, mapGroupPattern, classDisableValidation, methodDisableValidation + ); + } private static string ConfigureScenario( bool configureWithServiceProvider = false, @@ -526,15 +506,13 @@ private static string ConfigureScenario( bool includeMethodLevelFilter = false, bool includeGenericFilter = false, bool configureRegistersFilter = false, - string metadataValue = "Individual") - => SourceFactory.BuildConfigureAndFiltersSource( - configureWithServiceProvider, - configureAddsMetadata, - includeClassLevelFilter, - includeMethodLevelFilter, - includeGenericFilter, - configureRegistersFilter, - metadataValue); + string metadataValue = "Individual" + ) + { + return SourceFactory.BuildConfigureAndFiltersSource(configureWithServiceProvider, configureAddsMetadata, includeClassLevelFilter, + includeMethodLevelFilter, includeGenericFilter, configureRegistersFilter, metadataValue + ); + } private static string HttpMethodScenario( bool includeGet = false, @@ -547,19 +525,18 @@ private static string HttpMethodScenario( bool includeQuery = false, bool includeTrace = false, bool includeConnect = false, - bool includeMethodNameCollision = false) - => SourceFactory.BuildHttpMethodMatrixSource( - includeGet, - includePost, - includePut, - includeDelete, - includeOptions, - includeHead, - includePatch, - includeQuery, - includeTrace, - includeConnect, - includeMethodNameCollision); + bool includeMethodNameCollision = false + ) + { + return SourceFactory.BuildHttpMethodMatrixSource(includeGet, includePost, includePut, includeDelete, includeOptions, includeHead, includePatch, + includeQuery, includeTrace, includeConnect, includeMethodNameCollision + ); + } + + private static string EndpointNameCollisionScenario() + { + return SourceFactory.BuildEndpointNameCollisionSource(); + } private static string ContractScenario( bool includeBindingNames = false, @@ -580,59 +557,49 @@ private static string ContractScenario( string? acceptsContentType1 = null, string? acceptsContentType2 = null, string? producesContentType1 = null, - string? producesContentType2 = null) - => SourceFactory.BuildContractsAndBindingSource( - includeBindingNames, - includeAsParameters, - includeFromServices, - includeFromKeyedServices, - includeAccepts, - includeGenericAccepts, - includeProducesResponse, - includeProducesProblem, - includeProducesValidationProblem, - includeSummaryAndDescription, - includeDisplayName, - includeTags, - excludeFromDescription, - allowAnonymous, - methodRequiresAuthorization, - acceptsContentType1, - acceptsContentType2, - producesContentType1, - producesContentType2); + string? producesContentType2 = null + ) + { + return SourceFactory.BuildContractsAndBindingSource(includeBindingNames, includeAsParameters, includeFromServices, includeFromKeyedServices, + includeAccepts, includeGenericAccepts, includeProducesResponse, includeProducesProblem, includeProducesValidationProblem, + includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, methodRequiresAuthorization, + acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 + ); + } private static string AsyncHandlerScenario() - => """ - using System.Threading.Tasks; - - internal sealed class AsyncHandlerEndpoints - { - [MapGet("/task")] - public async Task TaskOnly() - { - await Task.Yield(); - } - - [MapGet("/task-result")] - public async Task, NotFound>> TaskWithResult(int id) - { - await Task.Yield(); - return id >= 0 ? TypedResults.Ok("task") : TypedResults.NotFound(); - } - - [MapPost("/valuetask")] - public async ValueTask ValueTaskOnly() - { - await Task.Yield(); - } - - [MapPost("/valuetask-result")] - public async ValueTask, NotFound>> ValueTaskWithResult(int id) - { - await Task.Yield(); - return id >= 0 ? TypedResults.Ok("value") : TypedResults.NotFound(); - } - } - """; + { + return """ + using System.Threading.Tasks; + + internal sealed class AsyncHandlerEndpoints + { + [MapGet("/task")] + public async Task TaskOnly() + { + await Task.Yield(); + } + + [MapGet("/task-result")] + public async Task, NotFound>> TaskWithResult(int id) + { + await Task.Yield(); + return id >= 0 ? TypedResults.Ok("task") : TypedResults.NotFound(); + } + + [MapPost("/valuetask")] + public async ValueTask ValueTaskOnly() + { + await Task.Yield(); + } + + [MapPost("/valuetask-result")] + public async ValueTask, NotFound>> ValueTaskWithResult(int id) + { + await Task.Yield(); + return id >= 0 ? TypedResults.Ok("value") : TypedResults.NotFound(); + } + } + """; + } }