Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 8 additions & 45 deletions src/GeneratedEndpoints/AddEndpointHandlersGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Text;
using System.Collections.Immutable;
using System.Text;
using GeneratedEndpoints.Common;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;
Expand All @@ -12,12 +13,15 @@ namespace GeneratedEndpoints;

internal static class AddEndpointHandlersGenerator
{
public static void GenerateSource(SourceProductionContext context, EquatableImmutableArray<RequestHandler> requestHandlers)
public static void GenerateSource(SourceProductionContext context, ImmutableSortedDictionary<RequestHandlerClass, ImmutableArray<RequestHandler>> grouped)
{
context.CancellationToken.ThrowIfCancellationRequested();

var nonStaticClassNames = GetDistinctNonStaticClassNames(requestHandlers);
var source = GetAddEndpointHandlersStringBuilder(nonStaticClassNames);
var nonStaticClassNames = grouped.Keys
.Where(x => !x.IsStatic)
.Select(x => x.Name)
.ToList();
var source = new StringBuilder();
source.AppendLine(FileHeader);

source.AppendLine();
Expand Down Expand Up @@ -61,45 +65,4 @@ public static void GenerateSource(SourceProductionContext context, EquatableImmu
var sourceText = StringBuilderPool.ToStringAndReturn(source);
context.AddSource(AddEndpointHandlersMethodHint, SourceText.From(sourceText, Encoding.UTF8));
}

private static List<string> GetDistinctNonStaticClassNames(EquatableImmutableArray<RequestHandler> requestHandlers)
{
var classNames = new List<string>();
if (requestHandlers.Count == 0)
return classNames;

var seen = new HashSet<string>(StringComparer.Ordinal);
for (var index = 0; index < requestHandlers.Count; index++)
{
var requestHandler = requestHandlers[index];
if (requestHandler.Class.IsStatic)
continue;

var className = requestHandler.Class.Name;
if (seen.Add(className))
classNames.Add(className);
}

return classNames;
}

private static StringBuilder GetAddEndpointHandlersStringBuilder(List<string> nonStaticClassNames)
{
var estimate = 512L;
for (var index = 0; index < nonStaticClassNames.Count; index++)
{
var className = nonStaticClassNames[index];
estimate += 36 + className.Length;
}

estimate += Math.Max(256, nonStaticClassNames.Count * 12);
estimate = (long)(estimate * 1.10);

if (estimate < 512)
estimate = 512;
else if (estimate > int.MaxValue)
estimate = int.MaxValue;

return StringBuilderPool.Get((int)estimate);
}
}
8 changes: 8 additions & 0 deletions src/GeneratedEndpoints/Common/AttributeDataExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ internal static class AttributeDataExtensions
return null;
}

public static ITypeSymbol? GetConstructorTypeSymbol(this AttributeData attribute, int position = 0)
{
if (attribute.ConstructorArguments.Length > position && attribute.ConstructorArguments[position].Value is ITypeSymbol typeSymbol)
return typeSymbol;

return null;
}

public static ITypeSymbol? GetNamedTypeSymbol(this AttributeData attribute, string namedParameter)
{
foreach (var namedArg in attribute.NamedArguments)
Expand Down
64 changes: 34 additions & 30 deletions src/GeneratedEndpoints/Common/Constants.GeneratedSources.cs

Large diffs are not rendered by default.

20 changes: 1 addition & 19 deletions src/GeneratedEndpoints/Common/Constants.cs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
using System.ComponentModel;

namespace GeneratedEndpoints.Common;

internal static partial class Constants
{
internal const string FallbackHttpMethod = "__FALLBACK__";

internal const string NameAttributeNamedParameter = "Name";
internal const string ResponseTypeAttributeNamedParameter = "ResponseType";
internal const string RequestTypeAttributeNamedParameter = "RequestType";
internal const string IsOptionalAttributeNamedParameter = "IsOptional";

internal const string RequireAuthorizationAttributeName = "RequireAuthorizationAttribute";
Expand Down Expand Up @@ -47,12 +43,6 @@ internal static partial class Constants
internal const string SummaryAttributeName = "SummaryAttribute";
internal const string SummaryAttributeHint = $"{SummaryAttributeFullyQualifiedName}.gs.cs";

internal const string DisplayNameAttributeName = nameof(DisplayNameAttribute);
internal const string DescriptionAttributeName = nameof(DescriptionAttribute);
internal const string AllowAnonymousAttributeName = "AllowAnonymousAttribute";
internal const string TagsAttributeName = "TagsAttribute";
internal const string ExcludeFromDescriptionAttributeName = "ExcludeFromDescriptionAttribute";

internal const string EndpointFilterAttributeName = "EndpointFilterAttribute";
internal const string EndpointFilterAttributeHint = $"{EndpointFilterAttributeFullyQualifiedName}.gs.cs";

Expand Down Expand Up @@ -82,15 +72,7 @@ internal static partial class Constants
internal const string AsyncSuffix = "Async";
internal const string ApplicationJsonContentType = "application/json";
internal const string GlobalPrefix = "global::";
internal const string Dot = ".";

internal static readonly string[] AttributesNamespaceParts = AttributesNamespace.Split('.');
internal static readonly string[] AspNetCoreHttpNamespaceParts = ["Microsoft", "AspNetCore", "Http"];
internal static readonly string[] AspNetCoreMvcNamespaceParts = ["Microsoft", "AspNetCore", "Mvc"];
internal static readonly string[] AspNetCoreAuthorizationNamespaceParts = ["Microsoft", "AspNetCore", "Authorization"];
internal static readonly string[] AspNetCoreRoutingNamespaceParts = ["Microsoft", "AspNetCore", "Routing"];
internal static readonly string[] ExtensionsDependencyInjectionNamespaceParts = ["Microsoft", "Extensions", "DependencyInjection"];
internal static readonly string[] ComponentModelNamespaceParts = ["System", "ComponentModel"];

private const string BaseNamespace = "Microsoft.AspNetCore.Generated";
private const string AttributesNamespace = $"{BaseNamespace}.Attributes";
private const string RequireAuthorizationAttributeFullyQualifiedName = $"{AttributesNamespace}.{RequireAuthorizationAttributeName}";
Expand Down
4 changes: 1 addition & 3 deletions src/GeneratedEndpoints/Common/EndpointConfiguration.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,5 @@ internal readonly record struct EndpointConfiguration
public required bool WithRequestTimeout { get; init; }
public required string? RequestTimeoutPolicyName { get; init; }
public required int? Order { get; init; }
public required string? GroupIdentifier { get; init; }
public required string? GroupPattern { get; init; }
public required string? GroupName { get; init; }
public required EndpointGroup? Group { get; init; }
}
67 changes: 31 additions & 36 deletions src/GeneratedEndpoints/Common/EndpointConfigurationFactory.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
using System.Runtime.CompilerServices;
using Microsoft.CodeAnalysis;
using static GeneratedEndpoints.Common.Constants;

namespace GeneratedEndpoints.Common;

internal static class EndpointConfigurationFactory
{
private static readonly ConditionalWeakTable<INamedTypeSymbol, GeneratedAttributeKindCacheEntry> GeneratedAttributeKindCache = new();

public static EndpointConfiguration Create(ISymbol symbol)
{
var attributes = symbol.GetAttributes();
Expand Down Expand Up @@ -48,7 +45,7 @@ public static EndpointConfiguration Create(ISymbol symbol)
if (attributeClass is null)
continue;

var attributeKind = GetGeneratedAttributeKind(attributeClass);
var attributeKind = attributeClass.OriginalDefinition.GetRequestHandlerAttributeKind();
switch (attributeKind)
{
case RequestHandlerAttributeKind.ShortCircuit:
Expand Down Expand Up @@ -172,9 +169,14 @@ public static EndpointConfiguration Create(ISymbol symbol)
WithRequestTimeout = withRequestTimeout ?? false,
RequestTimeoutPolicyName = requestTimeoutPolicyName,
Order = order,
GroupIdentifier = groupIdentifier,
GroupPattern = groupPattern,
GroupName = groupName,
Group = groupIdentifier is not null && groupPattern is not null
? new EndpointGroup
{
Identifier = groupIdentifier,
Pattern = groupPattern,
Name = groupName,
}
: null,
};
}

Expand All @@ -198,16 +200,6 @@ public static EndpointConfiguration Create(ISymbol symbol)
return StringBuilderPool.ToStringAndReturn(builder);
}

private static RequestHandlerAttributeKind GetGeneratedAttributeKind(INamedTypeSymbol attributeClass)
{
var definition = attributeClass.OriginalDefinition;
var cacheEntry = GeneratedAttributeKindCache.GetValue(
definition, static def => new GeneratedAttributeKindCacheEntry(def.GetRequestHandlerAttributeKind())
);

return cacheEntry.Kind;
}

private static EquatableImmutableArray<T>? ToEquatableOrNull<T>(List<T>? values)
{
return values is { Count: > 0 } ? values.ToEquatableImmutableArray() : null;
Expand All @@ -219,9 +211,11 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym
if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 })
requestType = attributeClass.TypeArguments[0]
.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
else if (attribute.GetNamedTypeSymbol(RequestTypeAttributeNamedParameter) is { } requestTypeSymbol)
requestType = requestTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
else
requestType = attribute.GetConstructorTypeSymbol()
?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);

if (requestType is null)
return;

var contentType = attribute.GetConstructorStringValue() ?? ApplicationJsonContentType;
Expand All @@ -236,23 +230,29 @@ private static void TryAddAcceptsMetadata(AttributeData attribute, INamedTypeSym

private static void TryAddProducesMetadata(AttributeData attribute, INamedTypeSymbol attributeClass, ref List<ProducesMetadata>? produces)
{
string? responseType;
ProducesMetadata? producesMetadata;
if (attributeClass is { IsGenericType: true, TypeArguments.Length: 1 })
responseType = attributeClass.TypeArguments[0]
{
var responseType = attributeClass.TypeArguments[0]
.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
else if (attribute.GetNamedTypeSymbol(ResponseTypeAttributeNamedParameter) is { } responseTypeSymbol)
responseType = responseTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
var statusCode = attribute.GetConstructorIntValue(0) ?? 200;
var contentType = attribute.GetConstructorStringValue(1);
var additionalContentTypes = attribute.GetConstructorStringArray(2);
producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes);
}
else
return;

var statusCode = attribute.GetConstructorIntValue() ?? 200;
var contentType = attribute.GetConstructorStringValue(1);
var additionalContentTypes = attribute.GetConstructorStringArray(2);

var producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes);
{
var responseType = attribute.GetConstructorTypeSymbol()
?.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)
?? "";
var statusCode = attribute.GetConstructorIntValue(1) ?? 200;
var contentType = attribute.GetConstructorStringValue(2);
var additionalContentTypes = attribute.GetConstructorStringArray(3);
producesMetadata = new ProducesMetadata(responseType, statusCode, contentType, additionalContentTypes);
}

var producesList = produces ??= [];
producesList.Add(producesMetadata);
producesList.Add(producesMetadata.Value);
}

private static void TryAddEndpointFilter(
Expand Down Expand Up @@ -291,9 +291,4 @@ private static void TryAddEndpointFilterType(ITypeSymbol? typeSymbol, ref List<s
endpointFilters ??= [];
endpointFilters.Add(displayString);
}

private sealed class GeneratedAttributeKindCacheEntry(RequestHandlerAttributeKind kind)
{
public RequestHandlerAttributeKind Kind { get; } = kind;
}
}
8 changes: 8 additions & 0 deletions src/GeneratedEndpoints/Common/EndpointGroup.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
namespace GeneratedEndpoints.Common;

internal readonly record struct EndpointGroup
{
public required string Identifier { get; init; }
public required string Pattern { get; init; }
public required string? Name { get; init; }
}
22 changes: 1 addition & 21 deletions src/GeneratedEndpoints/Common/MethodSymbolExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using static GeneratedEndpoints.Common.AttributeSymbolMatcher;
using static GeneratedEndpoints.Common.Constants;

namespace GeneratedEndpoints.Common;
Expand Down Expand Up @@ -45,7 +44,7 @@ private static string GetBindingPrefix(IParameterSymbol parameter)
if (attributeClass is null)
continue;

var attributeSource = GetBindingSourceFromAttributeClass(attributeClass);
var attributeSource = attributeClass.GetBindingSource();
if (attributeSource == BindingSource.None)
continue;

Expand All @@ -69,25 +68,6 @@ private static string GetBindingPrefix(IParameterSymbol parameter)
return bindingPrefix;
}

private static BindingSource GetBindingSourceFromAttributeClass(INamedTypeSymbol attributeClass)
{
var definition = attributeClass.OriginalDefinition;
var namespaceSymbol = definition.ContainingNamespace;

return definition.Name switch
{
"FromRouteAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromRoute,
"FromQueryAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromQuery,
"FromHeaderAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromHeader,
"FromBodyAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromBody,
"FromFormAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromForm,
"FromServicesAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreMvcNamespaceParts) => BindingSource.FromServices,
"FromKeyedServicesAttribute" when IsInNamespace(namespaceSymbol, ExtensionsDependencyInjectionNamespaceParts) => BindingSource.FromKeyedServices,
"AsParametersAttribute" when IsInNamespace(namespaceSymbol, AspNetCoreHttpNamespaceParts) => BindingSource.AsParameters,
_ => BindingSource.None,
};
}

private static string GetBindingSourceAttribute(BindingSource source, TypedConstant? typedKey, string? bindingName)
{
switch (source)
Expand Down
51 changes: 0 additions & 51 deletions src/GeneratedEndpoints/Common/NamedTypeSymbolExtensions.cs

This file was deleted.

Loading