diff --git a/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs index 2da92c7..6a6551f 100644 --- a/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs +++ b/src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs @@ -1,4 +1,5 @@ -using System.Text; +using System.Collections.Immutable; +using System.Text; using GeneratedEndpoints.Common; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.Text; @@ -12,12 +13,15 @@ namespace GeneratedEndpoints; internal static class AddEndpointHandlersGenerator { - public static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) + public static void GenerateSource(SourceProductionContext context, ImmutableSortedDictionary> grouped) { context.CancellationToken.ThrowIfCancellationRequested(); - var nonStaticClassNames = GetDistinctNonStaticClassNames(requestHandlers); - var source = GetAddEndpointHandlersStringBuilder(nonStaticClassNames); + var nonStaticClassNames = grouped.Keys + .Where(x => !x.IsStatic) + .Select(x => x.Name) + .ToList(); + var source = new StringBuilder(); source.AppendLine(FileHeader); source.AppendLine(); @@ -61,45 +65,4 @@ public static void GenerateSource(SourceProductionContext context, EquatableImmu var sourceText = StringBuilderPool.ToStringAndReturn(source); context.AddSource(AddEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); } - - private static List GetDistinctNonStaticClassNames(EquatableImmutableArray requestHandlers) - { - var classNames = new List(); - if (requestHandlers.Count == 0) - return classNames; - - var seen = new HashSet(StringComparer.Ordinal); - for (var index = 0; index < requestHandlers.Count; index++) - { - var requestHandler = requestHandlers[index]; - if (requestHandler.Class.IsStatic) - continue; - - var className = requestHandler.Class.Name; - if (seen.Add(className)) - classNames.Add(className); - } - - return classNames; - } - - private static StringBuilder GetAddEndpointHandlersStringBuilder(List nonStaticClassNames) - { - var estimate = 512L; - for (var index = 0; index < nonStaticClassNames.Count; index++) - { - var className = nonStaticClassNames[index]; - estimate += 36 + className.Length; - } - - estimate += Math.Max(256, nonStaticClassNames.Count * 12); - estimate = (long)(estimate * 1.10); - - if (estimate < 512) - estimate = 512; - else if (estimate > int.MaxValue) - estimate = int.MaxValue; - - return StringBuilderPool.Get((int)estimate); - } } diff --git a/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs index 1a10e01..9dab472 100644 --- a/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs +++ b/src/GeneratedEndpoints/Common/AttributeDataExtensions.cs @@ -56,6 +56,14 @@ internal static class AttributeDataExtensions return null; } + public static ITypeSymbol? GetConstructorTypeSymbol(this AttributeData attribute, int position = 0) + { + if (attribute.ConstructorArguments.Length > position && attribute.ConstructorArguments[position].Value is ITypeSymbol typeSymbol) + return typeSymbol; + + return null; + } + public static ITypeSymbol? GetNamedTypeSymbol(this AttributeData attribute, string namedParameter) { foreach (var namedArg in attribute.NamedArguments) diff --git a/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs b/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs index 6b6d8e5..1b90733 100644 --- a/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs +++ b/src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs @@ -370,14 +370,9 @@ namespace {{AttributesNamespace}}; internal sealed class {{AcceptsAttributeName}} : global::System.Attribute { /// - /// Gets the request type accepted by the endpoint. + /// Gets the CLR type of the endpoint filter. /// - public global::System.Type? RequestType { get; init; } - - /// - /// Gets a value indicating whether the request body is optional. - /// - public bool IsOptional { get; init; } + public global::System.Type Type { get; } /// /// Gets the primary content type accepted by the endpoint. @@ -389,15 +384,22 @@ internal sealed class {{AcceptsAttributeName}} : global::System.Attribute /// public string[] AdditionalContentTypes { get; } + /// + /// Gets a value indicating whether the request body is optional. + /// + public bool IsOptional { get; init; } + /// /// Initializes a new instance of the class. /// + /// The CLR type of the request body. /// The primary content type accepted by the endpoint. /// Additional content types accepted by the endpoint. - public {{AcceptsAttributeName}}(string contentType = "application/json", params string[] additionalContentTypes) + public {{AcceptsAttributeName}}(global::System.Type type, string contentType = "application/json", params string[] additionalContentTypes) { - ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; - AdditionalContentTypes = additionalContentTypes ?? []; + Type = type; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes; } } @@ -409,14 +411,9 @@ internal sealed class {{AcceptsAttributeName}} : global::System.Attribute internal sealed class {{AcceptsAttributeName}} : global::System.Attribute { /// - /// Gets the request type accepted by the endpoint. + /// Gets the CLR type of the endpoint filter. /// - public global::System.Type RequestType => typeof(TRequest); - - /// - /// Gets a value indicating whether the request body is optional. - /// - public bool IsOptional { get; init; } + public global::System.Type Type => typeof(TRequest); /// /// Gets the primary content type accepted by the endpoint. @@ -428,6 +425,11 @@ internal sealed class {{AcceptsAttributeName}} : global::System.Attrib /// public string[] AdditionalContentTypes { get; } + /// + /// Gets a value indicating whether the request body is optional. + /// + public bool IsOptional { get; init; } + /// /// Initializes a new instance of the generic Accepts attribute class. /// @@ -435,8 +437,8 @@ internal sealed class {{AcceptsAttributeName}} : global::System.Attrib /// Additional content types accepted by the endpoint. public {{AcceptsAttributeName}}(string contentType = "application/json", params string[] additionalContentTypes) { - ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; - AdditionalContentTypes = additionalContentTypes ?? []; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes; } } @@ -457,15 +459,15 @@ internal sealed class {{EndpointFilterAttributeName}} : global::System.Attribute /// /// Gets the CLR type of the endpoint filter. /// - public global::System.Type FilterType { get; } + public global::System.Type Type { get; } /// /// Initializes a new instance of the class. /// - /// The CLR type of the endpoint filter. - public {{EndpointFilterAttributeName}}(global::System.Type filterType) + /// The CLR type of the endpoint filter. + public {{EndpointFilterAttributeName}}(global::System.Type type) { - FilterType = filterType ?? throw new global::System.ArgumentNullException(nameof(filterType)); + Type = type; } } @@ -479,7 +481,7 @@ internal sealed class {{EndpointFilterAttributeName}} : global::System. /// /// Gets the CLR type of the endpoint filter. /// - public global::System.Type FilterType => typeof(TFilter); + public global::System.Type Type => typeof(TFilter); } """, Encoding.UTF8 @@ -496,10 +498,10 @@ namespace {{AttributesNamespace}}; [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribute { - /// - /// Gets the response type produced by the endpoint. - /// - public global::System.Type? ResponseType { get; init; } + /// + /// Gets the response type produced by the endpoint. + /// + public global::System.Type Type { get; } /// /// Gets the HTTP status code returned by the endpoint. @@ -519,11 +521,13 @@ internal sealed class {{ProducesResponseAttributeName}} : global::System.Attribu /// /// Initializes a new instance of the class. /// + /// The CLR type of the response body. /// The HTTP status code returned by the endpoint. /// The primary content type produced by the endpoint. /// Additional content types produced by the endpoint. - public {{ProducesResponseAttributeName}}(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) + public {{ProducesResponseAttributeName}}(global::System.Type type, int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) { + Type = type; StatusCode = statusCode; ContentType = contentType; AdditionalContentTypes = additionalContentTypes ?? []; @@ -540,7 +544,7 @@ internal sealed class {{ProducesResponseAttributeName}} : global::Sys /// /// Gets the response type produced by the endpoint. /// - public global::System.Type ResponseType => typeof(TResponse); + public global::System.Type Type => typeof(TResponse); /// /// Gets the HTTP status code returned by the endpoint. diff --git a/src/GeneratedEndpoints/Common/Constants.cs b/src/GeneratedEndpoints/Common/Constants.cs index 8ef77d1..d7cb149 100644 --- a/src/GeneratedEndpoints/Common/Constants.cs +++ b/src/GeneratedEndpoints/Common/Constants.cs @@ -1,5 +1,3 @@ -using System.ComponentModel; - namespace GeneratedEndpoints.Common; internal static partial class Constants @@ -7,8 +5,6 @@ internal static partial class Constants internal const string FallbackHttpMethod = "__FALLBACK__"; internal const string NameAttributeNamedParameter = "Name"; - internal const string ResponseTypeAttributeNamedParameter = "ResponseType"; - internal const string RequestTypeAttributeNamedParameter = "RequestType"; internal const string IsOptionalAttributeNamedParameter = "IsOptional"; internal const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute"; @@ -47,12 +43,6 @@ internal static partial class Constants internal const string SummaryAttributeName = "SummaryAttribute"; internal const string SummaryAttributeHint = $"{SummaryAttributeFullyQualifiedName}.gs.cs"; - internal const string DisplayNameAttributeName = nameof(DisplayNameAttribute); - internal const string DescriptionAttributeName = nameof(DescriptionAttribute); - internal const string AllowAnonymousAttributeName = "AllowAnonymousAttribute"; - internal const string TagsAttributeName = "TagsAttribute"; - internal const string ExcludeFromDescriptionAttributeName = "ExcludeFromDescriptionAttribute"; - internal const string EndpointFilterAttributeName = "EndpointFilterAttribute"; internal const string EndpointFilterAttributeHint = $"{EndpointFilterAttributeFullyQualifiedName}.gs.cs"; @@ -82,15 +72,7 @@ internal static partial class Constants internal const string AsyncSuffix = "Async"; internal const string ApplicationJsonContentType = "application/json"; internal const string GlobalPrefix = "global::"; - internal const string Dot = "."; - - internal static readonly string[] AttributesNamespaceParts = AttributesNamespace.Split('.'); - internal static readonly string[] AspNetCoreHttpNamespaceParts = ["Microsoft", "AspNetCore", "Http"]; - internal static readonly string[] AspNetCoreMvcNamespaceParts = ["Microsoft", "AspNetCore", "Mvc"]; - internal static readonly string[] AspNetCoreAuthorizationNamespaceParts = ["Microsoft", "AspNetCore", "Authorization"]; - internal static readonly string[] AspNetCoreRoutingNamespaceParts = ["Microsoft", "AspNetCore", "Routing"]; - internal static readonly string[] ExtensionsDependencyInjectionNamespaceParts = ["Microsoft", "Extensions", "DependencyInjection"]; - internal static readonly string[] ComponentModelNamespaceParts = ["System", "ComponentModel"]; + private const string BaseNamespace = "Microsoft.AspNetCore.Generated"; private const string AttributesNamespace = $"{BaseNamespace}.Attributes"; private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}"; diff --git a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs index 60c5bbc..3cf41fb 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfiguration.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfiguration.cs @@ -27,7 +27,5 @@ internal readonly record struct EndpointConfiguration public required bool WithRequestTimeout { get; init; } public required string? RequestTimeoutPolicyName { get; init; } public required int? Order { get; init; } - public required string? GroupIdentifier { get; init; } - public required string? GroupPattern { get; init; } - public required string? GroupName { get; init; } + public required EndpointGroup? Group { get; init; } } diff --git a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs index f986ac9..f8dccd5 100644 --- a/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs +++ b/src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs @@ -1,4 +1,3 @@ -using System.Runtime.CompilerServices; using Microsoft.CodeAnalysis; using static GeneratedEndpoints.Common.Constants; @@ -6,8 +5,6 @@ namespace GeneratedEndpoints.Common; internal static class EndpointConfigurationFactory { - private static readonly ConditionalWeakTable GeneratedAttributeKindCache = new(); - public static EndpointConfiguration Create(ISymbol symbol) { var attributes = symbol.GetAttributes(); @@ -48,7 +45,7 @@ public static EndpointConfiguration Create(ISymbol symbol) if (attributeClass is null) continue; - var attributeKind = GetGeneratedAttributeKind(attributeClass); + var attributeKind = attributeClass.OriginalDefinition.GetRequestHandlerAttributeKind(); switch (attributeKind) { case RequestHandlerAttributeKind.ShortCircuit: @@ -172,9 +169,14 @@ public static EndpointConfiguration Create(ISymbol symbol) WithRequestTimeout = withRequestTimeout ?? false, RequestTimeoutPolicyName = requestTimeoutPolicyName, Order = order, - GroupIdentifier = groupIdentifier, - GroupPattern = groupPattern, - GroupName = groupName, + Group = groupIdentifier is not null && groupPattern is not null + ? new EndpointGroup + { + Identifier = groupIdentifier, + Pattern = groupPattern, + Name = groupName, + } + : null, }; } @@ -198,16 +200,6 @@ public static EndpointConfiguration Create(ISymbol symbol) return StringBuilderPool.ToStringAndReturn(builder); } - private static RequestHandlerAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass) - { - var definition = attributeClass.OriginalDefinition; - var cacheEntry = GeneratedAttributeKindCache.GetValue( - definition, static def => new GeneratedAttributeKindCacheEntry(def.GetRequestHandlerAttributeKind()) - ); - - return cacheEntry.Kind; - } - private static EquatableImmutableArray? ToEquatableOrNull(List? values) { return values is { Count: > 0 } ? values.ToEquatableImmutableArray() : null; @@ -219,9 +211,11 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) requestType = attributeClass.TypeArguments[0] .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - else if (attribute.GetNamedTypeSymbol(RequestTypeAttributeNamedParameter) is { } requestTypeSymbol) - requestType = requestTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); else + requestType = attribute.GetConstructorTypeSymbol() + ?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + + if (requestType is null) return; var contentType = attribute.GetConstructorStringValue() ?? ApplicationJsonContentType; @@ -236,23 +230,29 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym private static void TryAddProducesMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List? produces) { - string? responseType; + ProducesMetadata? producesMetadata; if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 }) - responseType = attributeClass.TypeArguments[0] + { + var responseType = attributeClass.TypeArguments[0] .ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - else if (attribute.GetNamedTypeSymbol(ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol) - responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var statusCode = attribute.GetConstructorIntValue(0) ?? 200; + var contentType = attribute.GetConstructorStringValue(1); + var additionalContentTypes = attribute.GetConstructorStringArray(2); + producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes); + } else - return; - - var statusCode = attribute.GetConstructorIntValue() ?? 200; - var contentType = attribute.GetConstructorStringValue(1); - var additionalContentTypes = attribute.GetConstructorStringArray(2); - - var producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes); + { + var responseType = attribute.GetConstructorTypeSymbol() + ?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat) + ?? ""; + var statusCode = attribute.GetConstructorIntValue(1) ?? 200; + var contentType = attribute.GetConstructorStringValue(2); + var additionalContentTypes = attribute.GetConstructorStringArray(3); + producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes); + } var producesList = produces ??= []; - producesList.Add(producesMetadata); + producesList.Add(producesMetadata.Value); } private static void TryAddEndpointFilter( @@ -291,9 +291,4 @@ private static void TryAddEndpointFilterType(ITypeSymbol? typeSymbol, ref List BindingSource.FromRoute, - "FromQueryAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromQuery, - "FromHeaderAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromHeader, - "FromBodyAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromBody, - "FromFormAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromForm, - "FromServicesAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromServices, - "FromKeyedServicesAttribute" when IsInNamespace(namespaceSymbol, ExtensionsDependencyInjectionNamespaceParts) => BindingSource.FromKeyedServices, - "AsParametersAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreHttpNamespaceParts) => BindingSource.AsParameters, - _ => BindingSource.None, - }; - } - private static string GetBindingSourceAttribute(BindingSource source, TypedConstant? typedKey, string? bindingName) { switch (source) diff --git a/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs b/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs deleted file mode 100644 index 36c1bca..0000000 --- a/src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs +++ /dev/null @@ -1,51 +0,0 @@ -using Microsoft.CodeAnalysis; -using static GeneratedEndpoints.Common.Constants; - -namespace GeneratedEndpoints.Common; - -/// Provides extension methods for working with named type symbols. -internal static class NamedTypeSymbolExtensions -{ - public static RequestHandlerAttributeKind GetRequestHandlerAttributeKind(this INamedTypeSymbol definition) - { - if (AttributeSymbolMatcher.IsAttribute(definition, DisplayNameAttributeName, ComponentModelNamespaceParts)) - return RequestHandlerAttributeKind.DisplayName; - - if (AttributeSymbolMatcher.IsAttribute(definition, DescriptionAttributeName, ComponentModelNamespaceParts)) - return RequestHandlerAttributeKind.Description; - - if (AttributeSymbolMatcher.IsAttribute(definition, AllowAnonymousAttributeName, AspNetCoreAuthorizationNamespaceParts)) - return RequestHandlerAttributeKind.AllowAnonymous; - - if (AttributeSymbolMatcher.IsAttribute(definition, TagsAttributeName, AspNetCoreHttpNamespaceParts)) - return RequestHandlerAttributeKind.Tags; - - if (AttributeSymbolMatcher.IsAttribute(definition, ExcludeFromDescriptionAttributeName, AspNetCoreRoutingNamespaceParts)) - return RequestHandlerAttributeKind.ExcludeFromDescription; - - if (!AttributeSymbolMatcher.IsInNamespace(definition.ContainingNamespace, AttributesNamespaceParts)) - return RequestHandlerAttributeKind.None; - - return definition.Name switch - { - ShortCircuitAttributeName => RequestHandlerAttributeKind.ShortCircuit, - DisableValidationAttributeName => RequestHandlerAttributeKind.DisableValidation, - DisableRequestTimeoutAttributeName => RequestHandlerAttributeKind.DisableRequestTimeout, - RequestTimeoutAttributeName => RequestHandlerAttributeKind.RequestTimeout, - OrderAttributeName => RequestHandlerAttributeKind.Order, - MapGroupAttributeName => RequestHandlerAttributeKind.MapGroup, - SummaryAttributeName => RequestHandlerAttributeKind.Summary, - AcceptsAttributeName => RequestHandlerAttributeKind.Accepts, - ProducesResponseAttributeName => RequestHandlerAttributeKind.ProducesResponse, - RequireAuthorizationAttributeName => RequestHandlerAttributeKind.RequireAuthorization, - RequireCorsAttributeName => RequestHandlerAttributeKind.RequireCors, - RequireHostAttributeName => RequestHandlerAttributeKind.RequireHost, - RequireRateLimitingAttributeName => RequestHandlerAttributeKind.RequireRateLimiting, - EndpointFilterAttributeName => RequestHandlerAttributeKind.EndpointFilter, - DisableAntiforgeryAttributeName => RequestHandlerAttributeKind.DisableAntiforgery, - ProducesProblemAttributeName => RequestHandlerAttributeKind.ProducesProblem, - ProducesValidationProblemAttributeName => RequestHandlerAttributeKind.ProducesValidationProblem, - _ => RequestHandlerAttributeKind.None, - }; - } -} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClass.cs b/src/GeneratedEndpoints/Common/RequestHandlerClass.cs index c5e868b..e3c8336 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerClass.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerClass.cs @@ -1,9 +1,42 @@ namespace GeneratedEndpoints.Common; -internal readonly record struct RequestHandlerClass( - string Name, - bool IsStatic, - bool HasConfigureMethod, - bool ConfigureMethodAcceptsServiceProvider, - EndpointConfiguration Configuration -); +internal readonly record struct RequestHandlerClass : IComparable, IComparable +{ + public required string Name { get; init; } + public required bool IsStatic { get; init; } + public required bool HasConfigureMethod { get; init; } + public required bool ConfigureMethodAcceptsServiceProvider { get; init; } + public required EndpointConfiguration Configuration { get; init; } + + public int CompareTo(RequestHandlerClass other) + { + return string.Compare(Name, other.Name, StringComparison.Ordinal); + } + + public int CompareTo(object? obj) + { + if (obj is null) + return 1; + return obj is RequestHandlerClass other ? CompareTo(other) : throw new ArgumentException($"Object must be of type {nameof(RequestHandlerClass)}"); + } + + public static bool operator <(RequestHandlerClass left, RequestHandlerClass right) + { + return left.CompareTo(right) < 0; + } + + public static bool operator >(RequestHandlerClass left, RequestHandlerClass right) + { + return left.CompareTo(right) > 0; + } + + public static bool operator <=(RequestHandlerClass left, RequestHandlerClass right) + { + return left.CompareTo(right) <= 0; + } + + public static bool operator >=(RequestHandlerClass left, RequestHandlerClass right) + { + return left.CompareTo(right) >= 0; + } +} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs deleted file mode 100644 index 3226132..0000000 --- a/src/GeneratedEndpoints/Common/RequestHandlerClassCacheEntry.cs +++ /dev/null @@ -1,157 +0,0 @@ -using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp.Syntax; -using static GeneratedEndpoints.Common.Constants; - -namespace GeneratedEndpoints.Common; - -internal sealed class RequestHandlerClassCacheEntry -{ - private readonly object _lock = new(); - private RequestHandlerClass _value; - private bool _initialized; - - public RequestHandlerClass GetOrCreate(INamedTypeSymbol classSymbol, CancellationToken cancellationToken) - { - if (_initialized) - return _value; - - lock (_lock) - { - if (_initialized) - return _value; - - cancellationToken.ThrowIfCancellationRequested(); - - var name = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); - var isStatic = classSymbol.IsStatic; - var configureMethodDetails = GetConfigureMethodDetails(classSymbol, cancellationToken); - - var classConfiguration = EndpointConfigurationFactory.Create(classSymbol); - - _value = new RequestHandlerClass(name, isStatic, configureMethodDetails.HasConfigureMethod, - configureMethodDetails.ConfigureMethodAcceptsServiceProvider, classConfiguration - ); - _initialized = true; - return _value; - } - } - - private static ConfigureMethodDetails GetConfigureMethodDetails(INamedTypeSymbol classSymbol, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - var hasConfigureMethod = false; - var acceptsServiceProvider = false; - foreach (var member in classSymbol.GetMembers(ConfigureMethodName)) - { - cancellationToken.ThrowIfCancellationRequested(); - - if (member is not IMethodSymbol methodSymbol) - continue; - - if (IsConfigureMethod(methodSymbol, out var methodAcceptsServiceProvider)) - { - hasConfigureMethod = true; - if (methodAcceptsServiceProvider) - { - acceptsServiceProvider = true; - break; - } - } - } - - return new ConfigureMethodDetails(hasConfigureMethod, acceptsServiceProvider); - } - - private static bool IsConfigureMethod(IMethodSymbol methodSymbol, out bool acceptsServiceProvider) - { - acceptsServiceProvider = false; - - if (!methodSymbol.IsStatic) - return false; - - if (methodSymbol.TypeParameters.Length != 1) - return false; - - if (methodSymbol.Parameters.Length is < 1 or > 2) - return false; - - var builderTypeParameter = methodSymbol.TypeParameters[0]; - var builderParameter = methodSymbol.Parameters[0]; - - if (!SymbolEqualityComparer.Default.Equals(builderParameter.Type, builderTypeParameter)) - return false; - - if (methodSymbol.Parameters.Length == 2) - { - var serviceProviderParameter = methodSymbol.Parameters[1]; - if (!IsServiceProviderParameter(serviceProviderParameter.Type)) - return false; - - acceptsServiceProvider = true; - } - - if (!methodSymbol.ReturnsVoid) - return false; - - if (!HasEndpointConventionBuilderConstraint(builderTypeParameter, methodSymbol)) - return false; - - return true; - } - - private static bool IsServiceProviderParameter(ITypeSymbol typeSymbol) - { - return MatchesServiceProvider(typeSymbol); - } - - private static bool HasEndpointConventionBuilderConstraint(ITypeParameterSymbol builderTypeParameter, IMethodSymbol methodSymbol) - { - var symbolMatches = builderTypeParameter.ConstraintTypes.Any(MatchesEndpointConventionBuilder); - if (symbolMatches) - return true; - - return methodSymbol.DeclaringSyntaxReferences - .Select(reference => reference.GetSyntax()) - .OfType() - .SelectMany(methodSyntax => methodSyntax.ConstraintClauses) - .Where(clause => string.Equals(clause.Name.Identifier.ValueText, builderTypeParameter.Name, StringComparison.Ordinal)) - .SelectMany(clause => clause.Constraints.OfType()) - .Any(constraint => IsEndpointConventionBuilderIdentifier(constraint.Type)); - } - - private static bool IsEndpointConventionBuilderIdentifier(TypeSyntax typeSyntax) - { - return typeSyntax switch - { - QualifiedNameSyntax qualified => IsEndpointConventionBuilderIdentifier(qualified.Right), - AliasQualifiedNameSyntax alias => IsEndpointConventionBuilderIdentifier(alias.Name), - SimpleNameSyntax simple => string.Equals(simple.Identifier.ValueText, "IEndpointConventionBuilder", StringComparison.Ordinal), - _ => false, - }; - } - - private static bool MatchesEndpointConventionBuilder(ITypeSymbol typeSymbol) - { - if (typeSymbol is not INamedTypeSymbol namedType) - return false; - - if (!string.Equals(namedType.Name, "IEndpointConventionBuilder", StringComparison.Ordinal)) - return false; - - var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? ""; - return string.Equals(containingNamespace, "Microsoft.AspNetCore.Builder", StringComparison.Ordinal); - } - - private static bool MatchesServiceProvider(ITypeSymbol typeSymbol) - { - if (typeSymbol is not INamedTypeSymbol namedType) - return false; - - if (!string.Equals(namedType.Name, "IServiceProvider", StringComparison.Ordinal)) - return false; - - var containingNamespace = namedType.ContainingNamespace?.ToDisplayString() ?? ""; - return string.Equals(containingNamespace, "System", StringComparison.Ordinal); - } -} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerClassHelper.cs b/src/GeneratedEndpoints/Common/RequestHandlerClassHelper.cs new file mode 100644 index 0000000..806a254 --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerClassHelper.cs @@ -0,0 +1,96 @@ +using Microsoft.CodeAnalysis; +using static GeneratedEndpoints.Common.Constants; + +namespace GeneratedEndpoints.Common; + +internal static class RequestHandlerClassHelper +{ + public static RequestHandlerClass? Create(IMethodSymbol methodSymbol, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + var classSymbol = methodSymbol.ContainingType; + if (classSymbol.TypeKind != TypeKind.Class) + return null; + + var name = classSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat); + var isStatic = classSymbol.IsStatic; + var configureMethodDetails = GetConfigureMethodDetails(classSymbol, cancellationToken); + var classConfiguration = EndpointConfigurationFactory.Create(classSymbol); + + var requestHandlerClass = new RequestHandlerClass + { + Name = name, + IsStatic = isStatic, + HasConfigureMethod = configureMethodDetails.HasConfigureMethod, + ConfigureMethodAcceptsServiceProvider = configureMethodDetails.ConfigureMethodAcceptsServiceProvider, + Configuration = classConfiguration, + }; + + return requestHandlerClass; + } + + private static ConfigureMethodDetails GetConfigureMethodDetails(INamedTypeSymbol classSymbol, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + var hasConfigureMethod = false; + var acceptsServiceProvider = false; + foreach (var member in classSymbol.GetMembers(ConfigureMethodName)) + { + cancellationToken.ThrowIfCancellationRequested(); + + if (member is not IMethodSymbol methodSymbol) + continue; + + if (IsConfigureMethod(methodSymbol, out var methodAcceptsServiceProvider)) + { + hasConfigureMethod = true; + if (methodAcceptsServiceProvider) + { + acceptsServiceProvider = true; + break; + } + } + } + + return new ConfigureMethodDetails(hasConfigureMethod, acceptsServiceProvider); + } + + private static bool IsConfigureMethod(IMethodSymbol methodSymbol, out bool acceptsServiceProvider) + { + acceptsServiceProvider = false; + + if (!methodSymbol.IsStatic) + return false; + + if (methodSymbol.TypeParameters.Length != 1) + return false; + + if (methodSymbol.Parameters.Length is < 1 or > 2) + return false; + + var builderTypeParameter = methodSymbol.TypeParameters[0]; + var builderParameter = methodSymbol.Parameters[0]; + + if (!SymbolEqualityComparer.Default.Equals(builderParameter.Type, builderTypeParameter)) + return false; + + if (methodSymbol.Parameters.Length == 2) + { + var serviceProviderParameter = methodSymbol.Parameters[1]; + if (!serviceProviderParameter.Type.IsIServiceProvider()) + return false; + + acceptsServiceProvider = true; + } + + if (!methodSymbol.ReturnsVoid) + return false; + + if (!builderTypeParameter.ConstraintTypes.Any(x => x.IsIEndpointConventionBuilder())) + return false; + + return true; + } +} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs b/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs index 54e2958..52669c4 100644 --- a/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs +++ b/src/GeneratedEndpoints/Common/RequestHandlerMethod.cs @@ -1,8 +1,41 @@ namespace GeneratedEndpoints.Common; -internal readonly record struct RequestHandlerMethod( - string Name, - bool IsStatic, - EquatableImmutableArray Parameters, - EndpointConfiguration Configuration -); +internal readonly record struct RequestHandlerMethod : IComparable, IComparable +{ + public required string Name { get; init; } + public required bool IsStatic { get; init; } + public required EquatableImmutableArray Parameters { get; init; } + public required EndpointConfiguration Configuration { get; init; } + + public int CompareTo(RequestHandlerMethod other) + { + return string.Compare(Name, other.Name, StringComparison.Ordinal); + } + + public int CompareTo(object? obj) + { + if (obj is null) + return 1; + return obj is RequestHandlerMethod other ? CompareTo(other) : throw new ArgumentException($"Object must be of type {nameof(RequestHandlerMethod)}"); + } + + public static bool operator <(RequestHandlerMethod left, RequestHandlerMethod right) + { + return left.CompareTo(right) < 0; + } + + public static bool operator >(RequestHandlerMethod left, RequestHandlerMethod right) + { + return left.CompareTo(right) > 0; + } + + public static bool operator <=(RequestHandlerMethod left, RequestHandlerMethod right) + { + return left.CompareTo(right) <= 0; + } + + public static bool operator >=(RequestHandlerMethod left, RequestHandlerMethod right) + { + return left.CompareTo(right) >= 0; + } +} diff --git a/src/GeneratedEndpoints/Common/RequestHandlerMethodHelper.cs b/src/GeneratedEndpoints/Common/RequestHandlerMethodHelper.cs new file mode 100644 index 0000000..72193dc --- /dev/null +++ b/src/GeneratedEndpoints/Common/RequestHandlerMethodHelper.cs @@ -0,0 +1,25 @@ +using Microsoft.CodeAnalysis; + +namespace GeneratedEndpoints.Common; + +internal static class RequestHandlerMethodHelper +{ + public static RequestHandlerMethod Create(IMethodSymbol methodSymbol, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested(); + + var name = methodSymbol.Name; + var isStatic = methodSymbol.IsStatic; + var parameters = methodSymbol.GetParameters(cancellationToken); + var configuration = EndpointConfigurationFactory.Create(methodSymbol); + var requestHandlerMethod = new RequestHandlerMethod + { + Name = name, + IsStatic = isStatic, + Parameters = parameters, + Configuration = configuration, + }; + + return requestHandlerMethod; + } +} diff --git a/src/GeneratedEndpoints/Common/StringExtensions.cs b/src/GeneratedEndpoints/Common/StringExtensions.cs index 5867811..1ca906a 100644 --- a/src/GeneratedEndpoints/Common/StringExtensions.cs +++ b/src/GeneratedEndpoints/Common/StringExtensions.cs @@ -1,4 +1,5 @@ using System.Globalization; +using static GeneratedEndpoints.Common.Constants; namespace GeneratedEndpoints.Common; @@ -75,4 +76,12 @@ public static string NormalizeOrDefaultString(this string? value, string default { return string.IsNullOrWhiteSpace(value) ? defaultValue : value!.Trim(); } + + public static string RemoveAsyncSuffix(this string methodName) + { + if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) + return methodName[..^AsyncSuffix.Length]; + + return methodName; + } } diff --git a/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs b/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs index 0b69f45..c7dcbe4 100644 --- a/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs +++ b/src/GeneratedEndpoints/Common/TypeSymbolExtensions.cs @@ -4,12 +4,474 @@ namespace GeneratedEndpoints.Common; internal static class TypeSymbolExtensions { + public static RequestHandlerAttributeKind GetRequestHandlerAttributeKind(this ITypeSymbol symbol) + { + var definition = symbol.OriginalDefinition; + return definition switch + { + { + MetadataName: "DisplayNameAttribute", + ContainingNamespace: + { + Name: "ComponentModel", + ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true }, + }, + } => RequestHandlerAttributeKind.DisplayName, + { + MetadataName: "DescriptionAttribute", + ContainingNamespace: + { + Name: "ComponentModel", + ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true }, + }, + } => RequestHandlerAttributeKind.Description, + { + MetadataName: "AllowAnonymousAttribute", + ContainingNamespace: + { + Name: "Authorization", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + } => RequestHandlerAttributeKind.AllowAnonymous, + { + MetadataName: "TagsAttribute", + ContainingNamespace: + { + Name: "Http", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + } => RequestHandlerAttributeKind.Tags, + { + MetadataName: "ExcludeFromDescriptionAttribute", + ContainingNamespace: + { + Name: "Routing", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + } => RequestHandlerAttributeKind.ExcludeFromDescription, + { + MetadataName: "ExcludeFromDescriptionAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.ExcludeFromDescription, + { + MetadataName: "ShortCircuitAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.ShortCircuit, + { + MetadataName: "DisableValidationAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.DisableValidation, + { + MetadataName: "DisableRequestTimeoutAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.DisableRequestTimeout, + { + MetadataName: "RequestTimeoutAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.RequestTimeout, + { + MetadataName: "OrderAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.Order, + { + MetadataName: "MapGroupAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.MapGroup, + { + MetadataName: "SummaryAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.Summary, + { + MetadataName: "AcceptsAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.Accepts, + { + MetadataName: "AcceptsAttribute`1", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.Accepts, + { + MetadataName: "ProducesResponseAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.ProducesResponse, + { + MetadataName: "ProducesResponseAttribute`1", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.ProducesResponse, + { + MetadataName: "RequireAuthorizationAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.RequireAuthorization, + { + MetadataName: "RequireCorsAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.RequireCors, + { + MetadataName: "RequireHostAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.RequireHost, + { + MetadataName: "RequireRateLimitingAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.RequireRateLimiting, + { + MetadataName: "EndpointFilterAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.EndpointFilter, + { + MetadataName: "EndpointFilterAttribute`1", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.EndpointFilter, + { + MetadataName: "DisableAntiforgeryAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.DisableAntiforgery, + { + MetadataName: "ProducesProblemAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.ProducesProblem, + { + MetadataName: "ProducesValidationProblemAttribute", + ContainingNamespace: + { + Name: "Attributes", + ContainingNamespace: + { + Name: "Generated", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }, + } => RequestHandlerAttributeKind.ProducesValidationProblem, + _ => RequestHandlerAttributeKind.None, + }; + } + + public static BindingSource GetBindingSource(this ITypeSymbol symbol) + { + var definition = symbol.OriginalDefinition; + return definition switch + { + { + MetadataName: "FromRouteAttribute", + ContainingNamespace: + { + Name: "Mvc", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromRoute, + { + MetadataName: "FromQueryAttribute", + ContainingNamespace: + { + Name: "Mvc", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromQuery, + { + MetadataName: "FromHeaderAttribute", + ContainingNamespace: + { + Name: "Mvc", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromHeader, + { + MetadataName: "FromBodyAttribute", + ContainingNamespace: + { + Name: "Mvc", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromBody, + { + MetadataName: "FromFormAttribute", + ContainingNamespace: + { + Name: "Mvc", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromForm, + { + MetadataName: "FromServicesAttribute", + ContainingNamespace: + { + Name: "Mvc", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromServices, + { + MetadataName: "FromKeyedServicesAttribute", + ContainingNamespace: + { + Name: "DependencyInjection", + ContainingNamespace: + { + Name: "Extensions", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.FromKeyedServices, + { + MetadataName: "AsParametersAttribute", + ContainingNamespace: + { + Name: "Http", + ContainingNamespace: + { + Name: "AspNetCore", + ContainingNamespace: + { + Name: "Microsoft", + ContainingNamespace.IsGlobalNamespace: true, + }, + }, + }, + } => BindingSource.AsParameters, + _ => BindingSource.None, + }; + } + + public static bool IsIEndpointConventionBuilder(this ITypeSymbol symbol) + { + return symbol is + { + MetadataName: "IEndpointConventionBuilder", + ContainingNamespace: + { + Name: "Builder", + ContainingNamespace: { Name: "AspNetCore", ContainingNamespace: { Name: "Microsoft", ContainingNamespace.IsGlobalNamespace: true } }, + }, + }; + } + + public static bool IsIServiceProvider(this ITypeSymbol symbol) + { + return symbol is + { + MetadataName: "IServiceProvider", + ContainingNamespace: + { + Name: "System", ContainingNamespace.IsGlobalNamespace: true, + }, + }; + } + public static bool IsAwaitable(this ITypeSymbol symbol) { return symbol switch { - INamedTypeSymbol - { + { MetadataName: "ValueTask`1", ContainingNamespace: { @@ -17,7 +479,7 @@ public static bool IsAwaitable(this ITypeSymbol symbol) ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, }, } - or INamedTypeSymbol + or { MetadataName: "Task`1", ContainingNamespace: @@ -26,7 +488,7 @@ or INamedTypeSymbol ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, }, } - or INamedTypeSymbol + or { MetadataName: "ValueTask", ContainingNamespace: @@ -35,7 +497,7 @@ or INamedTypeSymbol ContainingNamespace: { Name: "Threading", ContainingNamespace: { Name: "System", ContainingNamespace.IsGlobalNamespace: true } }, }, } - or INamedTypeSymbol + or { MetadataName: "Task", ContainingNamespace: diff --git a/src/GeneratedEndpoints/GeneratedEndpoints.csproj b/src/GeneratedEndpoints/GeneratedEndpoints.csproj index ca27350..efddba0 100644 --- a/src/GeneratedEndpoints/GeneratedEndpoints.csproj +++ b/src/GeneratedEndpoints/GeneratedEndpoints.csproj @@ -36,9 +36,9 @@ GeneratedEndpoints - 10.0.1 - 10.0.1.0 - 10.0.1.0 + 10.0.2 + 10.0.2.0 + 10.0.2.0 en-US false diff --git a/src/GeneratedEndpoints/MinimalApiGenerator.cs b/src/GeneratedEndpoints/MinimalApiGenerator.cs index 00181d4..873acfd 100644 --- a/src/GeneratedEndpoints/MinimalApiGenerator.cs +++ b/src/GeneratedEndpoints/MinimalApiGenerator.cs @@ -1,5 +1,4 @@ using System.Collections.Immutable; -using System.Runtime.CompilerServices; using GeneratedEndpoints.Common; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp.Syntax; @@ -14,48 +13,15 @@ namespace GeneratedEndpoints; [Generator] public sealed class MinimalApiGenerator : IIncrementalGenerator { - private static readonly ConditionalWeakTable RequestHandlerClassCache = new(); - public void Initialize(IncrementalGeneratorInitializationContext context) { context.RegisterPostInitializationOutput(RegisterAttributes); - var requestHandlerProviders = ImmutableArray.CreateBuilder>>(HttpAttributeDefinitions.Length); - - for (var index = 0; index < HttpAttributeDefinitions.Length; index++) - { - var definition = HttpAttributeDefinitions[index]; - var handlers = context.SyntaxProvider - .ForAttributeWithMetadataName(definition.FullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) - .WhereNotNull() - .Collect(); - - requestHandlerProviders.Add(handlers); - } - - var requestHandlers = CombineRequestHandlers(requestHandlerProviders.MoveToImmutable()) - .Select((x, _) => x.ToEquatableImmutableArray()); + var requestHandlers = GetRequestHandlers(context); context.RegisterSourceOutput(requestHandlers, GenerateSource); } - private static IncrementalValueProvider> CombineRequestHandlers( - ImmutableArray>> handlerProviders - ) - { - if (handlerProviders.IsDefaultOrEmpty) - throw new InvalidOperationException("No HTTP attribute definitions were provided."); - - var combined = handlerProviders[0]; - for (var i = 1; i < handlerProviders.Length; i++) - { - combined = combined.Combine(handlerProviders[i]) - .Select(static (x, _) => x.Left.AddRange(x.Right)); - } - - return combined; - } - private static void RegisterAttributes(IncrementalGeneratorPostInitializationContext context) { foreach (var definition in HttpAttributeDefinitions) @@ -80,6 +46,44 @@ private static void RegisterAttributes(IncrementalGeneratorPostInitializationCon context.AddSource(SummaryAttributeHint, SummaryAttributeSourceText); } + private static IncrementalValueProvider> GetRequestHandlers(IncrementalGeneratorInitializationContext context) + { + var list = new List>>(HttpAttributeDefinitions.Length); + + for (var index = 0; index < HttpAttributeDefinitions.Length; index++) + { + var definition = HttpAttributeDefinitions[index]; + var handlers = context.SyntaxProvider + .ForAttributeWithMetadataName(definition.FullyQualifiedName, RequestHandlerFilter, RequestHandlerTransform) + .WhereNotNull() + .Collect(); + + list.Add(handlers); + } + + if (list.Count == 0) + throw new InvalidOperationException("No HTTP attribute definitions were provided."); + + var combined = list[0]; + for (var i = 1; i < list.Count; i++) + { + combined = combined.Combine(list[i]) + .Select(static (x, ct) => + { + ct.ThrowIfCancellationRequested(); + return x.Left.AddRange(x.Right); + } + ); + } + + return combined.Select((x, ct) => + { + ct.ThrowIfCancellationRequested(); + return x.ToEquatableImmutableArray(); + } + ); + } + private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToken cancellationToken) { cancellationToken.ThrowIfCancellationRequested(); @@ -95,11 +99,11 @@ private static bool RequestHandlerFilter(SyntaxNode syntaxNode, CancellationToke return null; var attribute = context.Attributes[0]; - var requestHandlerClass = GetRequestHandlerClass(methodSymbol, cancellationToken); + var requestHandlerClass = RequestHandlerClassHelper.Create(methodSymbol, cancellationToken); if (requestHandlerClass is null) return null; - var requestHandlerMethod = GetRequestHandlerMethod(methodSymbol, cancellationToken); + var requestHandlerMethod = RequestHandlerMethodHelper.Create(methodSymbol, cancellationToken); var (httpMethod, pattern, name) = GetRequestHandlerAttribute(methodSymbol, attribute, cancellationToken); @@ -127,69 +131,32 @@ CancellationToken cancellationToken var httpMethod = HttpAttributeDefinitionsByName.TryGetValue(attributeName, out var definition) ? definition.Verb : ""; var pattern = attribute.GetConstructorStringValue() ?? ""; var name = attribute.GetNamedStringValue(NameAttributeNamedParameter); - name ??= RemoveAsyncSuffix(methodSymbol.Name); + name ??= methodSymbol.Name.RemoveAsyncSuffix(); return (httpMethod, pattern, name); } - private static string RemoveAsyncSuffix(string methodName) - { - if (methodName.EndsWith(AsyncSuffix, StringComparison.OrdinalIgnoreCase) && methodName.Length > AsyncSuffix.Length) - return methodName[..^AsyncSuffix.Length]; - - return methodName; - } - - private static RequestHandlerClass? GetRequestHandlerClass(IMethodSymbol methodSymbol, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - var classSymbol = methodSymbol.ContainingType; - if (classSymbol.TypeKind != TypeKind.Class) - return null; - - var cacheEntry = RequestHandlerClassCache.GetValue(classSymbol, static _ => new RequestHandlerClassCacheEntry()); - var requestHandlerClass = cacheEntry.GetOrCreate(classSymbol, cancellationToken); - - return requestHandlerClass; - } - - private static RequestHandlerMethod GetRequestHandlerMethod(IMethodSymbol methodSymbol, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - var name = methodSymbol.Name; - var isStatic = methodSymbol.IsStatic; - var parameters = methodSymbol.GetParameters(cancellationToken); - var configuration = EndpointConfigurationFactory.Create(methodSymbol); - var requestHandlerMethod = new RequestHandlerMethod(name, isStatic, parameters, configuration); - - return requestHandlerMethod; - } - private static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) { context.CancellationToken.ThrowIfCancellationRequested(); - var normalized = NormalizeRequestHandlers(requestHandlers); - - AddEndpointHandlersGenerator.GenerateSource(context, normalized); - UseEndpointHandlersGenerator.GenerateSource(context, normalized); - } - - private static EquatableImmutableArray NormalizeRequestHandlers(EquatableImmutableArray requestHandlers) - { - if (requestHandlers.Count <= 1) - return requestHandlers; - - requestHandlers.SortInPlace(RequestHandlerComparer.Instance); ResolveEndpointNameCollisions(requestHandlers); - return requestHandlers; + var grouped = requestHandlers.GroupBy(x => x.Class) + .OrderBy(x => x.Key) + .ToImmutableSortedDictionary(x => x.Key, x => x.OrderBy(y => y.Method) + .ToImmutableArray() + ); + + AddEndpointHandlersGenerator.GenerateSource(context, grouped); + UseEndpointHandlersGenerator.GenerateSource(context, grouped); } private static void ResolveEndpointNameCollisions(EquatableImmutableArray requestHandlers) { + if (requestHandlers.Count <= 1) + return; + var raw = requestHandlers.AsArray(); if (raw is null) return; diff --git a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs index 93ba94e..8711e7f 100644 --- a/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs +++ b/src/GeneratedEndpoints/UseEndpointHandlersGenerator.cs @@ -13,11 +13,15 @@ namespace GeneratedEndpoints; internal static class UseEndpointHandlersGenerator { - public static void GenerateSource(SourceProductionContext context, EquatableImmutableArray requestHandlers) + public static void GenerateSource(SourceProductionContext context, ImmutableSortedDictionary> grouped) { context.CancellationToken.ThrowIfCancellationRequested(); - var source = GetUseEndpointHandlersStringBuilder(requestHandlers); + var requestHandlers = grouped.Values + .SelectMany(x => x) + .ToImmutableList(); + + var source = new StringBuilder(); source.AppendLine(FileHeader); source.AppendLine(); @@ -26,7 +30,7 @@ public static void GenerateSource(SourceProductionContext context, EquatableImmu source.AppendLine("using Microsoft.AspNetCore.Http;"); source.AppendLine("using Microsoft.AspNetCore.Mvc;"); source.AppendLine("using Microsoft.AspNetCore.Routing;"); - if (HasRateLimitedHandlers(requestHandlers)) + if (AddUsingRateLimiting(requestHandlers)) source.AppendLine("using Microsoft.AspNetCore.RateLimiting;"); source.AppendLine("using Microsoft.Extensions.DependencyInjection;"); source.AppendLine(); @@ -49,17 +53,23 @@ public static void GenerateSource(SourceProductionContext context, EquatableImmu source.AppendLine(" {"); - var groupedClasses = GetClassesWithMapGroups(requestHandlers); + var groupedClasses = GetClassesWithGroups(requestHandlers); for (var index = 0; index < groupedClasses.Count; index++) { var groupedClass = groupedClasses[index]; + var configuration = groupedClass.Configuration; + if (!configuration.Group.HasValue) + continue; + + var group = configuration.Group.Value; + source.Append(" var "); - source.Append(groupedClass.Configuration.GroupIdentifier); + source.Append(group.Identifier); source.Append(" = builder.MapGroup("); - source.Append(groupedClass.Configuration.GroupPattern!.ToStringLiteral()); + source.Append(group.Pattern.ToStringLiteral()); source.Append(')'); - AppendEndpointConfiguration(source, " ", groupedClass.Configuration); + AppendEndpointConfiguration(source, " ", configuration); source.AppendLine(";"); } @@ -87,7 +97,7 @@ public static void GenerateSource(SourceProductionContext context, EquatableImmu context.AddSource(UseEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8)); } - private static bool HasRateLimitedHandlers(EquatableImmutableArray requestHandlers) + private static bool AddUsingRateLimiting(ImmutableList requestHandlers) { for (var index = 0; index < requestHandlers.Count; index++) { @@ -99,25 +109,29 @@ private static bool HasRateLimitedHandlers(EquatableImmutableArray GetClassesWithMapGroups(EquatableImmutableArray requestHandlers) + private static List GetClassesWithGroups(ImmutableList requestHandlers) { - var groupedClasses = new List(); if (requestHandlers.Count == 0) - return groupedClasses; + return []; - var seen = new HashSet(StringComparer.Ordinal); + HashSet? seen = null; + List? groupedClasses = null; for (var index = 0; index < requestHandlers.Count; index++) { var handler = requestHandlers[index]; var handlerClass = handler.Class; - if (handlerClass.Configuration.GroupPattern is null) + if (!handlerClass.Configuration.Group.HasValue) continue; - if (seen.Add(handlerClass.Name)) - groupedClasses.Add(handlerClass); + seen ??= new HashSet(StringComparer.Ordinal); + if (!seen.Add(handlerClass.Name)) + continue; + + groupedClasses ??= []; + groupedClasses.Add(handlerClass); } - return groupedClasses; + return groupedClasses ?? []; } private static void GenerateMapRequestHandler(StringBuilder source, RequestHandler requestHandler) @@ -126,7 +140,7 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl var configureAcceptsServiceProvider = requestHandler.Class.ConfigureMethodAcceptsServiceProvider; var indent = wrapWithConfigure ? " " : " "; var continuationIndent = indent + " "; - var routeBuilderIdentifier = requestHandler.Class.Configuration.GroupIdentifier ?? "builder"; + var routeBuilderIdentifier = requestHandler.Class.Configuration.Group?.Identifier ?? "builder"; if (wrapWithConfigure) { @@ -200,7 +214,7 @@ private static void GenerateMapRequestHandler(StringBuilder source, RequestHandl source.Append(')'); var configuration = requestHandler.Method.Configuration; - if (requestHandler.Class.Configuration.GroupPattern is null) + if (!requestHandler.Class.Configuration.Group.HasValue) configuration = MergeEndpointConfigurations(requestHandler.Class.Configuration, configuration); if (!string.IsNullOrEmpty(requestHandler.Name)) @@ -261,12 +275,12 @@ private static void AppendEndpointConfiguration(StringBuilder source, string ind source.Append(')'); } - if (!string.IsNullOrEmpty(configuration.GroupName)) + if (configuration.Group is { Name.Length: > 0 }) { source.AppendLine(); source.Append(indent); source.Append(".WithGroupName("); - source.Append(configuration.GroupName.ToStringLiteral()); + source.Append(configuration.Group.Value.Name.ToStringLiteral()); source.Append(')'); } @@ -465,51 +479,32 @@ private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfigu var displayName = methodConfiguration.DisplayName ?? classConfiguration.DisplayName; var summary = methodConfiguration.Summary ?? classConfiguration.Summary; var description = methodConfiguration.Description ?? classConfiguration.Description; - var tags = MergeDistinctStrings(methodConfiguration.Tags, classConfiguration.Tags); + var excludeFromDescription = methodConfiguration.ExcludeFromDescription || classConfiguration.ExcludeFromDescription; + var accepts = ConcatEquatable(methodConfiguration.Accepts, classConfiguration.Accepts); var produces = ConcatEquatable(methodConfiguration.Produces, classConfiguration.Produces); var producesProblem = ConcatEquatable(methodConfiguration.ProducesProblem, classConfiguration.ProducesProblem); var producesValidationProblem = ConcatEquatable(methodConfiguration.ProducesValidationProblem, classConfiguration.ProducesValidationProblem); - var excludeFromDescription = methodConfiguration.ExcludeFromDescription || classConfiguration.ExcludeFromDescription; - - var authorizationPolicies = MergeDistinctStrings(methodConfiguration.AuthorizationPolicies, classConfiguration.AuthorizationPolicies); - var requiredHosts = MergeDistinctStrings(methodConfiguration.RequiredHosts, classConfiguration.RequiredHosts); - var endpointFilterTypes = ConcatEquatable(methodConfiguration.EndpointFilterTypes, classConfiguration.EndpointFilterTypes); - - var requireAuthorization = methodConfiguration.RequireAuthorization || classConfiguration.RequireAuthorization; - var disableAntiforgery = methodConfiguration.DisableAntiforgery || classConfiguration.DisableAntiforgery; - var allowAnonymous = methodConfiguration.AllowAnonymous || classConfiguration.AllowAnonymous; - - var requireCors = methodConfiguration.RequireCors || classConfiguration.RequireCors; - var corsPolicyName = methodConfiguration.CorsPolicyName ?? classConfiguration.CorsPolicyName; - - var requireRateLimiting = methodConfiguration.RequireRateLimiting || classConfiguration.RequireRateLimiting; - var rateLimitingPolicyName = methodConfiguration.RateLimitingPolicyName ?? classConfiguration.RateLimitingPolicyName; - var shortCircuit = methodConfiguration.ShortCircuit || classConfiguration.ShortCircuit; + var order = methodConfiguration.Order ?? classConfiguration.Order; + var disableAntiforgery = methodConfiguration.DisableAntiforgery || classConfiguration.DisableAntiforgery; var disableValidation = methodConfiguration.DisableValidation || classConfiguration.DisableValidation; - var disableRequestTimeout = methodConfiguration.DisableRequestTimeout || classConfiguration.DisableRequestTimeout; - var withRequestTimeout = methodConfiguration.WithRequestTimeout || classConfiguration.WithRequestTimeout; + var requiredHosts = MergeDistinctStrings(methodConfiguration.RequiredHosts, classConfiguration.RequiredHosts); + var endpointFilterTypes = ConcatEquatable(methodConfiguration.EndpointFilterTypes, classConfiguration.EndpointFilterTypes); - var groupIdentifier = methodConfiguration.GroupIdentifier ?? classConfiguration.GroupIdentifier; - var groupPattern = methodConfiguration.GroupPattern ?? classConfiguration.GroupPattern; - var groupName = methodConfiguration.GroupName ?? classConfiguration.GroupName; + var (allowAnonymous, requireAuthorization) = ResolveAuthorization(methodConfiguration, classConfiguration); + var authorizationPolicies = MergeDistinctStrings(methodConfiguration.AuthorizationPolicies, classConfiguration.AuthorizationPolicies); - string? requestTimeoutPolicyName = null; - if (methodConfiguration.WithRequestTimeout) - requestTimeoutPolicyName = methodConfiguration.RequestTimeoutPolicyName; - else if (classConfiguration.WithRequestTimeout) - requestTimeoutPolicyName = classConfiguration.RequestTimeoutPolicyName; + var (requireCors, corsPolicyName) = ResolveCors(methodConfiguration, classConfiguration); + var (requireRateLimiting, rateLimitingPolicyName) = ResolveRateLimiting(methodConfiguration, classConfiguration); - if (disableRequestTimeout) - { - withRequestTimeout = false; - requestTimeoutPolicyName = null; - } + var (disableRequestTimeout, withRequestTimeout, requestTimeoutPolicyName) = ResolveRequestTimeout(methodConfiguration, classConfiguration); - var order = methodConfiguration.Order ?? classConfiguration.Order; + var groupIdentifier = methodConfiguration.Group?.Identifier ?? classConfiguration.Group?.Identifier; + var groupPattern = methodConfiguration.Group?.Pattern ?? classConfiguration.Group?.Pattern; + var groupName = methodConfiguration.Group?.Name ?? classConfiguration.Group?.Name; return new EndpointConfiguration { @@ -538,12 +533,109 @@ private static EndpointConfiguration MergeEndpointConfigurations(EndpointConfigu WithRequestTimeout = withRequestTimeout, RequestTimeoutPolicyName = requestTimeoutPolicyName, Order = order, - GroupIdentifier = groupIdentifier, - GroupPattern = groupPattern, - GroupName = groupName, + Group = groupIdentifier is not null && groupPattern is not null + ? new EndpointGroup + { + Identifier = groupIdentifier, + Pattern = groupPattern, + Name = groupName, + } + : null, }; } + private static (bool AllowAnonymous, bool RequireAuthorization) ResolveAuthorization( + EndpointConfiguration methodConfiguration, + EndpointConfiguration classConfiguration + ) + { + var methodReq = methodConfiguration.RequireAuthorization; + var methodAnon = !methodReq && methodConfiguration.AllowAnonymous; + + var classReq = classConfiguration.RequireAuthorization; + var classAnon = !classReq && classConfiguration.AllowAnonymous; + + var methodDeclares = methodConfiguration.AllowAnonymous || methodConfiguration.RequireAuthorization; + + if (methodDeclares) + { + // Method directive wins + if (methodReq) + return (AllowAnonymous: false, RequireAuthorization: true); + + if (methodAnon) + return (AllowAnonymous: true, RequireAuthorization: false); + + return (false, false); + } + + if (classReq) + return (AllowAnonymous: false, RequireAuthorization: true); + + if (classAnon) + return (AllowAnonymous: true, RequireAuthorization: false); + + return (AllowAnonymous: false, RequireAuthorization: false); + } + + private static (bool DisableRequestTimeout, bool WithRequestTimeout, string? RequestTimeoutPolicyName) ResolveRequestTimeout( + EndpointConfiguration methodConfiguration, + EndpointConfiguration classConfiguration + ) + { + var methodWith = methodConfiguration.WithRequestTimeout; + var methodDisable = !methodWith && methodConfiguration.DisableRequestTimeout; + + var classWith = classConfiguration.WithRequestTimeout; + var classDisable = !classWith && classConfiguration.DisableRequestTimeout; + + var methodDeclares = methodConfiguration.DisableRequestTimeout || methodConfiguration.WithRequestTimeout; + + if (methodDeclares) + { + if (methodWith) + return (DisableRequestTimeout: false, WithRequestTimeout: true, methodConfiguration.RequestTimeoutPolicyName); + + if (methodDisable) + return (DisableRequestTimeout: true, WithRequestTimeout: false, null); + + return (false, false, null); + } + + if (classWith) + return (DisableRequestTimeout: false, WithRequestTimeout: true, classConfiguration.RequestTimeoutPolicyName); + + if (classDisable) + return (DisableRequestTimeout: true, WithRequestTimeout: false, null); + + return (DisableRequestTimeout: false, WithRequestTimeout: false, null); + } + + private static (bool RequireCors, string? CorsPolicyName) ResolveCors(EndpointConfiguration methodConfiguration, EndpointConfiguration classConfiguration) + { + if (methodConfiguration.RequireCors) + return (RequireCors: true, methodConfiguration.CorsPolicyName); + + if (classConfiguration.RequireCors) + return (RequireCors: true, classConfiguration.CorsPolicyName); + + return (RequireCors: false, CorsPolicyName: null); + } + + private static (bool RequireRateLimiting, string? RateLimitingPolicyName) ResolveRateLimiting( + EndpointConfiguration methodConfiguration, + EndpointConfiguration classConfiguration + ) + { + if (methodConfiguration.RequireRateLimiting) + return (RequireRateLimiting: true, methodConfiguration.RateLimitingPolicyName); + + if (classConfiguration.RequireRateLimiting) + return (RequireRateLimiting: true, classConfiguration.RateLimitingPolicyName); + + return (RequireRateLimiting: false, RateLimitingPolicyName: null); + } + private static EquatableImmutableArray? MergeDistinctStrings(EquatableImmutableArray? first, EquatableImmutableArray? second) { if (first is not { Count: > 0 }) @@ -611,21 +703,6 @@ private static EquatableImmutableArray MergeUnion(EquatableImmutableArra }; } - private static StringBuilder GetUseEndpointHandlersStringBuilder(EquatableImmutableArray requestHandlers) - { - const int baseSize = 4096; - const int perHandler = 512; - - var handlerCount = Math.Max(requestHandlers.Count, 0); - var estimate = baseSize + (long)perHandler * handlerCount; - estimate = (long)(estimate * 1.10); - - if (estimate > int.MaxValue) - estimate = int.MaxValue; - - return StringBuilderPool.Get((int)Math.Max(baseSize, estimate)); - } - private static void AppendAdditionalContentTypes(StringBuilder source, EquatableImmutableArray? additionalContentTypes) { if (additionalContentTypes is not { Count: > 0 }) diff --git a/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs b/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs index 31fd2cc..313239a 100644 --- a/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs +++ b/tests/GeneratedEndpoints.Tests.Lab/GetUserEndpoint.cs @@ -22,9 +22,9 @@ internal sealed class GetUserEndpoint(IServiceProvider serviceProvider) { [Tags("Featured")] [AllowAnonymous] - [Accepts("application/json", "application/xml", RequestType = typeof(GetUserRequest))] + [Accepts(typeof(GetUserRequest), "application/json", "application/xml")] [Accepts("application/json", "application/xml", IsOptional = true)] - [ProducesResponse(StatusCodes.Status200OK, "application/json", ResponseType = typeof(UserProfile))] + [ProducesResponse(typeof(UserProfile), StatusCodes.Status200OK, "application/json")] [ProducesResponse(StatusCodes.Status202Accepted, "application/json")] [ProducesProblem(StatusCodes.Status500InternalServerError, "application/problem+json")] [ProducesValidationProblem(StatusCodes.Status400BadRequest, "application/problem+json")] diff --git a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.AcceptsAttribute.verified.txt b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.AcceptsAttribute.verified.txt index d4d0ac4..ed86e13 100644 --- a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.AcceptsAttribute.verified.txt +++ b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.AcceptsAttribute.verified.txt @@ -19,14 +19,9 @@ namespace Microsoft.AspNetCore.Generated.Attributes; internal sealed class AcceptsAttribute : global::System.Attribute { /// - /// Gets the request type accepted by the endpoint. + /// Gets the CLR type of the endpoint filter. /// - public global::System.Type? RequestType { get; init; } - - /// - /// Gets a value indicating whether the request body is optional. - /// - public bool IsOptional { get; init; } + public global::System.Type Type { get; } /// /// Gets the primary content type accepted by the endpoint. @@ -38,15 +33,22 @@ internal sealed class AcceptsAttribute : global::System.Attribute /// public string[] AdditionalContentTypes { get; } + /// + /// Gets a value indicating whether the request body is optional. + /// + public bool IsOptional { get; init; } + /// /// Initializes a new instance of the class. /// + /// The CLR type of the request body. /// The primary content type accepted by the endpoint. /// Additional content types accepted by the endpoint. - public AcceptsAttribute(string contentType = "application/json", params string[] additionalContentTypes) + public AcceptsAttribute(global::System.Type type, string contentType = "application/json", params string[] additionalContentTypes) { - ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; - AdditionalContentTypes = additionalContentTypes ?? []; + Type = type; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes; } } @@ -58,14 +60,9 @@ internal sealed class AcceptsAttribute : global::System.Attribute internal sealed class AcceptsAttribute : global::System.Attribute { /// - /// Gets the request type accepted by the endpoint. + /// Gets the CLR type of the endpoint filter. /// - public global::System.Type RequestType => typeof(TRequest); - - /// - /// Gets a value indicating whether the request body is optional. - /// - public bool IsOptional { get; init; } + public global::System.Type Type => typeof(TRequest); /// /// Gets the primary content type accepted by the endpoint. @@ -77,6 +74,11 @@ internal sealed class AcceptsAttribute : global::System.Attribute /// public string[] AdditionalContentTypes { get; } + /// + /// Gets a value indicating whether the request body is optional. + /// + public bool IsOptional { get; init; } + /// /// Initializes a new instance of the generic Accepts attribute class. /// @@ -84,7 +86,7 @@ internal sealed class AcceptsAttribute : global::System.Attribute /// Additional content types accepted by the endpoint. public AcceptsAttribute(string contentType = "application/json", params string[] additionalContentTypes) { - ContentType = string.IsNullOrWhiteSpace(contentType) ? "application/json" : contentType; - AdditionalContentTypes = additionalContentTypes ?? []; + ContentType = contentType; + AdditionalContentTypes = additionalContentTypes; } } diff --git a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.EndpointFilterAttribute.verified.txt b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.EndpointFilterAttribute.verified.txt index e96674f..7e69fab 100644 --- a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.EndpointFilterAttribute.verified.txt +++ b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.EndpointFilterAttribute.verified.txt @@ -21,15 +21,15 @@ internal sealed class EndpointFilterAttribute : global::System.Attribute /// /// Gets the CLR type of the endpoint filter. /// - public global::System.Type FilterType { get; } + public global::System.Type Type { get; } /// /// Initializes a new instance of the class. /// - /// The CLR type of the endpoint filter. - public EndpointFilterAttribute(global::System.Type filterType) + /// The CLR type of the endpoint filter. + public EndpointFilterAttribute(global::System.Type type) { - FilterType = filterType ?? throw new global::System.ArgumentNullException(nameof(filterType)); + Type = type; } } @@ -43,5 +43,5 @@ internal sealed class EndpointFilterAttribute : global::System.Attribut /// /// Gets the CLR type of the endpoint filter. /// - public global::System.Type FilterType => typeof(TFilter); + public global::System.Type Type => typeof(TFilter); } diff --git a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.ProducesResponseAttribute.verified.txt b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.ProducesResponseAttribute.verified.txt index e5d6748..25e39fb 100644 --- a/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.ProducesResponseAttribute.verified.txt +++ b/tests/GeneratedEndpoints.Tests/AttributeGenerationTests.ProducesResponseAttribute.verified.txt @@ -18,10 +18,10 @@ namespace Microsoft.AspNetCore.Generated.Attributes; [global::System.AttributeUsage(global::System.AttributeTargets.Method, Inherited = false, AllowMultiple = true)] internal sealed class ProducesResponseAttribute : global::System.Attribute { -/// -/// Gets the response type produced by the endpoint. -/// -public global::System.Type? ResponseType { get; init; } + /// + /// Gets the response type produced by the endpoint. + /// + public global::System.Type Type { get; } /// /// Gets the HTTP status code returned by the endpoint. @@ -41,11 +41,13 @@ public global::System.Type? ResponseType { get; init; } /// /// Initializes a new instance of the class. /// + /// The CLR type of the response body. /// The HTTP status code returned by the endpoint. /// The primary content type produced by the endpoint. /// Additional content types produced by the endpoint. - public ProducesResponseAttribute(int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) + public ProducesResponseAttribute(global::System.Type type, int statusCode = global::Microsoft.AspNetCore.Http.StatusCodes.Status200OK, string? contentType = null, params string[] additionalContentTypes) { + Type = type; StatusCode = statusCode; ContentType = contentType; AdditionalContentTypes = additionalContentTypes ?? []; @@ -62,7 +64,7 @@ internal sealed class ProducesResponseAttribute : global::System.Attr /// /// Gets the response type produced by the endpoint. /// - public global::System.Type ResponseType => typeof(TResponse); + public global::System.Type Type => typeof(TResponse); /// /// Gets the HTTP status code returned by the endpoint. diff --git a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs index 4fef5bd..035ce8d 100644 --- a/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs +++ b/tests/GeneratedEndpoints.Tests/Common/SourceFactory.cs @@ -322,6 +322,7 @@ public static string BuildContractsAndBindingSource( bool includeAccepts, bool includeGenericAccepts, bool includeProducesResponse, + bool includeGenericProducesResponse, bool includeProducesProblem, bool includeProducesValidationProblem, bool includeSummaryAndDescription, @@ -378,11 +379,12 @@ public static string BuildContractsAndBindingSource( if (includeProducesResponse) { var secondProduces = string.IsNullOrWhiteSpace(producesContentType2) ? "" : $", \"{producesContentType2}\""; - builder.AppendLine( - $" [ProducesResponse(200, \"{producesContentType1 ?? "application/json"}\"{secondProduces}, ResponseType = typeof(ResponseRecord))]" - ); + builder.AppendLine($" [ProducesResponse(typeof(ResponseRecord), 200, \"{producesContentType1 ?? "application/json"}\"{secondProduces})]"); } + if (includeGenericProducesResponse) + builder.AppendLine($" [ProducesResponse(200, \"{producesContentType1 ?? "application/json"}\")]"); + if (includeProducesProblem) builder.AppendLine($" [ProducesProblem(500, \"{producesContentType1 ?? "application/problem+json"}\")]"); diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.AuthorizationAndMetadataMatrix_6F94B85A7155_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.AuthorizationAndMetadataMatrix_6F94B85A7155_MapEndpointHandlers.verified.txt index ed92292..ed5e8f5 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.AuthorizationAndMetadataMatrix_6F94B85A7155_MapEndpointHandlers.verified.txt +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.AuthorizationAndMetadataMatrix_6F94B85A7155_MapEndpointHandlers.verified.txt @@ -29,7 +29,6 @@ internal static class EndpointRouteBuilderExtensions .RequireAuthorization("MethodPolicy") .RequireCors("MethodCors") .RequireHost("services.contoso.com", "contoso.com") - .AllowAnonymous() .DisableRequestTimeout(); return builder; diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_AddEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_AddEndpointHandlers.verified.txt new file mode 100644 index 0000000..7a6643b --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_AddEndpointHandlers.verified.txt @@ -0,0 +1,24 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointServicesExtensions +{ + internal static void AddEndpointHandlers(this IServiceCollection services) + { + services.TryAddScoped(); + } +} diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_MapEndpointHandlers.verified.txt new file mode 100644 index 0000000..c9e81bb --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_0BE3EC6390D4_MapEndpointHandlers.verified.txt @@ -0,0 +1,35 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointRouteBuilderExtensions +{ + internal static IEndpointRouteBuilder MapEndpointHandlers(this IEndpointRouteBuilder builder) + { + builder.MapGet("/contracts/{id:int}", global::ContractEndpoints.Handle) + .WithName("Handle") + .WithDisplayName("Contract endpoint") + .WithTags("Contracts", "Bindings") + .Produces(200, "application/problem+json") + .ProducesProblem(500, "application/problem+json") + .AllowAnonymous(); + + return builder; + } +} diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_AddEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_AddEndpointHandlers.verified.txt new file mode 100644 index 0000000..7a6643b --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_AddEndpointHandlers.verified.txt @@ -0,0 +1,24 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointServicesExtensions +{ + internal static void AddEndpointHandlers(this IServiceCollection services) + { + services.TryAddScoped(); + } +} diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_MapEndpointHandlers.verified.txt new file mode 100644 index 0000000..d535349 --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_2F85DF1A025B_MapEndpointHandlers.verified.txt @@ -0,0 +1,35 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointRouteBuilderExtensions +{ + internal static IEndpointRouteBuilder MapEndpointHandlers(this IEndpointRouteBuilder builder) + { + builder.MapGet("/contracts/{id:int}", global::ContractEndpoints.Handle) + .WithName("Handle") + .WithDisplayName("Contract endpoint") + .ExcludeFromDescription() + .Produces(200, "application/json") + .ProducesProblem(500, "application/json") + .RequireAuthorization("ContractsPolicy"); + + return builder; + } +} diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_8330CA9A1CFC_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_8330CA9A1CFC_MapEndpointHandlers.verified.txt index 493d722..311982e 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_8330CA9A1CFC_MapEndpointHandlers.verified.txt +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_8330CA9A1CFC_MapEndpointHandlers.verified.txt @@ -29,8 +29,7 @@ internal static class EndpointRouteBuilderExtensions .WithDescription("Shows binding and contract combinations.") .Accepts("application/json") .ProducesValidationProblem(422, "application/problem+json") - .RequireAuthorization("ContractsPolicy") - .AllowAnonymous(); + .RequireAuthorization("ContractsPolicy"); return builder; } diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_9F5FE6E1F139_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_9F5FE6E1F139_MapEndpointHandlers.verified.txt index 70f58ad..74ba401 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_9F5FE6E1F139_MapEndpointHandlers.verified.txt +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_9F5FE6E1F139_MapEndpointHandlers.verified.txt @@ -30,6 +30,7 @@ internal static class EndpointRouteBuilderExtensions .WithTags("Contracts", "Bindings") .Accepts("application/xml") .Produces(200, "application/json", "text/json") + .Produces(200, "application/json") .ProducesProblem(500, "application/json") .ProducesValidationProblem(422, "application/json") .AllowAnonymous(); diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_AddEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_AddEndpointHandlers.verified.txt new file mode 100644 index 0000000..6e79c25 --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_AddEndpointHandlers.verified.txt @@ -0,0 +1,24 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointServicesExtensions +{ + internal static void AddEndpointHandlers(this IServiceCollection services) + { + services.TryAddScoped(); + } +} diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_MapEndpointHandlers.verified.txt new file mode 100644 index 0000000..a5b8b0d --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.ContractsAndBindingMatrix_E255BF1E7B54_MapEndpointHandlers.verified.txt @@ -0,0 +1,36 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointRouteBuilderExtensions +{ + internal static IEndpointRouteBuilder MapEndpointHandlers(this IEndpointRouteBuilder builder) + { + builder.MapGet("/contracts/{id:int}", global::GeneratedEndpointsTests.ContractEndpoints.Handle) + .WithName("Handle") + .WithDisplayName("Contract endpoint") + .WithTags("Contracts", "Bindings") + .Accepts("application/json") + .Produces(200, "application/json") + .ProducesValidationProblem(422, "application/problem+json") + .AllowAnonymous(); + + return builder; + } +} diff --git a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs index dd38cbb..725fa30 100644 --- a/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs +++ b/tests/GeneratedEndpoints.Tests/GeneratedSourceTests.cs @@ -171,17 +171,19 @@ await result.VerifyAsync("MapEndpointHandlers.g.cs") } [Theory] - [InlineData(true, true, true, true, true, true, true, true, true, true, true, true, true, false, true, false, "application/xml", "text/xml", + [InlineData(true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, true, false, "application/xml", "text/xml", "application/json", "text/json" )] - [InlineData(false, false, true, false, false, true, false, true, true, false, false, false, true, true, false, true, "application/custom", null, + [InlineData(false, false, true, false, false, true, false, true, false, true, false, false, true, true, false, true, false, "application/custom", null, "application/problem+json", null )] - [InlineData(true, true, false, true, true, false, true, false, false, true, true, true, false, false, true, true, null, null, null, null)] - [InlineData(false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, "application/xml", null, + [InlineData(true, true, false, true, true, false, true, false, false, false, true, true, true, false, false, true, true, null, null, null, null)] + [InlineData(false, true, false, true, false, true, false, true, false, true, false, false, true, false, true, false, true, "application/xml", null, "application/json", null )] - [InlineData(true, false, true, false, true, false, true, false, true, false, true, false, true, false, true, false, null, "text/plain", null, "text/plain")] + [InlineData(true, false, true, false, true, false, true, false, true, false, true, false, true, true, false, true, false, null, "text/plain", null, + "text/plain" + )] public async Task ContractsAndBindingMatrix( bool withNamespace, bool includeBindingNames, @@ -191,6 +193,7 @@ public async Task ContractsAndBindingMatrix( bool includeAccepts, bool includeGenericAccepts, bool includeProducesResponse, + bool includeGenericProducesResponse, bool includeProducesProblem, bool includeProducesValidationProblem, bool includeSummaryAndDescription, @@ -206,9 +209,9 @@ public async Task ContractsAndBindingMatrix( ) { var source = SourceFactory.BuildContractsAndBindingSource(includeBindingNames, includeAsParameters, includeFromServices, includeFromKeyedServices, - includeAccepts, includeGenericAccepts, includeProducesResponse, includeProducesProblem, includeProducesValidationProblem, - includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, methodRequiresAuthorization, - acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 + includeAccepts, includeGenericAccepts, includeProducesResponse, includeGenericProducesResponse, includeProducesProblem, + includeProducesValidationProblem, includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, + methodRequiresAuthorization, acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 ); var sources = TestHelpers.GetSources(source, withNamespace); diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_AddEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_AddEndpointHandlers.verified.txt new file mode 100644 index 0000000..6e79c25 --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_AddEndpointHandlers.verified.txt @@ -0,0 +1,24 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointServicesExtensions +{ + internal static void AddEndpointHandlers(this IServiceCollection services) + { + services.TryAddScoped(); + } +} diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_MapEndpointHandlers.verified.txt b/tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_MapEndpointHandlers.verified.txt new file mode 100644 index 0000000..8a60c59 --- /dev/null +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.GenericProducesResponseAttribute_MapEndpointHandlers.verified.txt @@ -0,0 +1,31 @@ +//----------------------------------------------------------------------------- +// +// This code was generated by MinimalApiGenerator which can be found +// in the GeneratedEndpoints namespace. +// +// Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated. +// +//----------------------------------------------------------------------------- + +#nullable enable + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Generated.Routing; + +internal static class EndpointRouteBuilderExtensions +{ + internal static IEndpointRouteBuilder MapEndpointHandlers(this IEndpointRouteBuilder builder) + { + builder.MapGet("/contracts/{id:int}", global::GeneratedEndpointsTests.ContractEndpoints.Handle) + .WithName("Handle") + .Produces(200, "application/json"); + + return builder; + } +} diff --git a/tests/GeneratedEndpoints.Tests/IndividualTests.cs b/tests/GeneratedEndpoints.Tests/IndividualTests.cs index be6ee72..bd0d6c3 100644 --- a/tests/GeneratedEndpoints.Tests/IndividualTests.cs +++ b/tests/GeneratedEndpoints.Tests/IndividualTests.cs @@ -378,6 +378,13 @@ public async Task ProducesResponseAttribute() await VerifyIndividualAsync(source, nameof(ProducesResponseAttribute)); } + [Fact] + public async Task GenericProducesResponseAttribute() + { + var source = ContractScenario(includeGenericProducesResponse: true, producesContentType1: "application/json"); + await VerifyIndividualAsync(source, nameof(GenericProducesResponseAttribute)); + } + [Fact] public async Task ProducesResponseMultipleContentTypes() { @@ -546,6 +553,7 @@ private static string ContractScenario( bool includeAccepts = false, bool includeGenericAccepts = false, bool includeProducesResponse = false, + bool includeGenericProducesResponse = false, bool includeProducesProblem = false, bool includeProducesValidationProblem = false, bool includeSummaryAndDescription = false, @@ -561,9 +569,9 @@ private static string ContractScenario( ) { return SourceFactory.BuildContractsAndBindingSource(includeBindingNames, includeAsParameters, includeFromServices, includeFromKeyedServices, - includeAccepts, includeGenericAccepts, includeProducesResponse, includeProducesProblem, includeProducesValidationProblem, - includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, methodRequiresAuthorization, - acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 + includeAccepts, includeGenericAccepts, includeProducesResponse, includeGenericProducesResponse, includeProducesProblem, + includeProducesValidationProblem, includeSummaryAndDescription, includeDisplayName, includeTags, excludeFromDescription, allowAnonymous, + methodRequiresAuthorization, acceptsContentType1, acceptsContentType2, producesContentType1, producesContentType2 ); }