Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,22 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor
private IEntityType? _entityType;

// Extract MethodInfo via expression trees (trim-safe; computed once per AppDomain)
private static readonly MethodInfo _select =
private readonly static MethodInfo _select =
((MethodCallExpression)((Expression<Func<IQueryable<object>, IQueryable<object>>>)
(q => q.Select(x => x))).Body).Method.GetGenericMethodDefinition();

private static readonly MethodInfo _where =
private readonly static MethodInfo _where =
((MethodCallExpression)((Expression<Func<IQueryable<object>, IQueryable<object>>>)
(q => q.Where(x => true))).Body).Method.GetGenericMethodDefinition();

// Static caches — keyed by CLR type, shared across all instances for the AppDomain lifetime.
// ConditionalWeakTable uses "ephemeron" semantics: the Type key is not kept alive by the
// cache entry, so types from collectible AssemblyLoadContexts can still be unloaded.
private readonly static ConditionalWeakTable<Type, StrongBox<bool>> _compilerGeneratedClosureCache = new();
private readonly static ConditionalWeakTable<Type, PropertyInfo[]> _projectablePropertiesCache = new();
private readonly static ConditionalWeakTable<Type, MethodInfo> _closedSelectCache = new();
private readonly static ConditionalWeakTable<Type, MethodInfo> _closedWhereCache = new();

public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver, bool trackByDefault = false)
{
_trackingByDefault = trackByDefault;
Expand Down Expand Up @@ -84,7 +92,6 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
// // case of a first()
// return obj.MyMap(x => new Obj {});
// }


if (call.Method.ReturnType.IsAssignableTo(typeof(IQueryable)))
{
Expand All @@ -101,7 +108,8 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
// before the query become executed by EF (before the .First()), we rewrite the .First(where)
// as .Where(where).Select(x => ...).First()

var where = Expression.Call(null, _where.MakeGenericMethod(_entityType.ClrType), call.Arguments);
var whereMethod = _closedWhereCache.GetValue(_entityType.ClrType, t => _where.MakeGenericMethod(t));
var where = Expression.Call(null, whereMethod, call.Arguments);
// The call instance is based on the wrong polymorphied method.
var first = call.Method.DeclaringType?.GetMethods()
.FirstOrDefault(x => x.Name == call.Method.Name && x.GetParameters().Length == 1);
Expand Down Expand Up @@ -138,18 +146,27 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
protected override Expression VisitMethodCall(MethodCallExpression node)
{
// Replace MethodGroup arguments with their reflected expressions.
// Note that MethodCallExpression.Update returns the original Expression if argument values have not changed.
node = node.Update(node.Object, node.Arguments.Select(arg => arg switch {
UnaryExpression {
NodeType: ExpressionType.Convert,
Operand: MethodCallExpression {
NodeType: ExpressionType.Call,
Method: { Name: nameof(MethodInfo.CreateDelegate), DeclaringType.Name: nameof(MethodInfo) },
Object: ConstantExpression { Value: MethodInfo methodInfo }
}
} => TryGetReflectedExpression(methodInfo, out var expressionArg) ? expressionArg : arg,
_ => arg
}));
// No-alloc fast-path: scan args without allocating; only copy the array and call
// Update() when a replacement is actually found (method-group arguments are rare).
Expression[]? updatedArgs = null;
for (var i = 0; i < node.Arguments.Count; i++)
{
if (node.Arguments[i] is UnaryExpression {
NodeType: ExpressionType.Convert,
Operand: MethodCallExpression {
NodeType: ExpressionType.Call,
Method: { Name: nameof(MethodInfo.CreateDelegate), DeclaringType.Name: nameof(MethodInfo) },
Object: ConstantExpression { Value: MethodInfo capturedMethodInfo }
}
} && TryGetReflectedExpression(capturedMethodInfo, out var expressionArg))
{
(updatedArgs ??= [.. node.Arguments])[i] = expressionArg;
}
}
if (updatedArgs is not null)
{
node = node.Update(node.Object, updatedArgs);
}

// Get the overriding methodInfo based on te type of the received of this expression
var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method;
Expand All @@ -172,7 +189,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
{
for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++)
{
var parameterExpession = reflectedExpression.Parameters[parameterIndex];
var parameterExpression = reflectedExpression.Parameters[parameterIndex];
var mappedArgumentExpression = (parameterIndex, node.Object) switch {
(0, not null) => node.Object,
(_, not null) => node.Arguments[parameterIndex - 1],
Expand All @@ -181,7 +198,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node)

if (mappedArgumentExpression is not null)
{
_expressionArgumentReplacer.ParameterArgumentMapping.Add(parameterExpession, mappedArgumentExpression);
_expressionArgumentReplacer.ParameterArgumentMapping.Add(parameterExpression, mappedArgumentExpression);
}
}

Expand Down Expand Up @@ -232,19 +249,35 @@ protected override Expression VisitMember(MemberExpression node)
{
// Evaluate captured variables in closures that contain EF queries to inline them into the main query
if (node.Expression is ConstantExpression constant &&
constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) &&
Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true))
IsCompilerGeneratedClosure(constant.Type))
{
try
{
var value = Expression
.Lambda<Func<object>>(Expression.Convert(node, typeof(object)))
.Compile()
.Invoke();
// Cheap type check first: only call GetValue() when the declared type
// could possibly hold an IQueryable at runtime. We use IEnumerable as
// the gate (rather than IQueryable) because a variable legitimately
// declared as IEnumerable<T> may hold an EF Core IQueryable<T> at
// runtime — both interfaces share the same assignability chain.
// FieldType / PropertyType are free property reads on already-
// materialised MemberInfo objects, so this check is cheap.
var memberType = node.Member switch {
FieldInfo field => field.FieldType,
PropertyInfo prop => prop.PropertyType,
_ => null
};

if (value is IQueryable queryable && ReferenceEquals(queryable.Provider, _currentQueryProvider))
if (memberType is not null && typeof(IEnumerable).IsAssignableFrom(memberType))
{
return Visit(queryable.Expression);
var value = node.Member switch {
FieldInfo field => field.GetValue(constant.Value),
PropertyInfo prop => prop.GetValue(constant.Value),
_ => null
};

if (value is IQueryable queryable && ReferenceEquals(queryable.Provider, _currentQueryProvider))
{
return Visit(queryable.Expression);
}
}
}
catch
Expand Down Expand Up @@ -275,16 +308,10 @@ PropertyInfo property when nodeExpression is not null
var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();

return base.Visit(
updatedBody
);
}
else
{
return base.Visit(
reflectedExpression.Body
);
return base.Visit(updatedBody);
}

return base.Visit(reflectedExpression.Body);
}

return base.VisitMember(node);
Expand All @@ -303,12 +330,13 @@ protected override Expression VisitExtension(Expression node)

private Expression _AddProjectableSelect(Expression node, IEntityType entityType)
{
var projectableProperties = entityType.ClrType.GetProperties()
.Where(x => x.IsDefined(typeof(ProjectableAttribute), false))
.Where(x => x.CanWrite)
.ToList();
var projectableProperties = _projectablePropertiesCache.GetValue(
entityType.ClrType,
static t => t.GetProperties()
.Where(x => x.IsDefined(typeof(ProjectableAttribute), false) && x.CanWrite)
.ToArray());

if (!projectableProperties.Any())
if (projectableProperties.Length == 0)
{
return node;
}
Expand All @@ -327,7 +355,7 @@ private Expression _AddProjectableSelect(Expression node, IEntityType entityType
.Where(x => projectableProperties.All(y => x.Name != y.Name && x.Name != $"<{y.Name}>k__BackingField"));

// Replace db.Entities to db.Entities.Select(x => new Entity { Property1 = x.Property1, Rewritted = rewrittedProperty })
var select = _select.MakeGenericMethod(entityType.ClrType, entityType.ClrType);
var select = _closedSelectCache.GetValue(entityType.ClrType, t => _select.MakeGenericMethod(t, t));
var xParam = Expression.Parameter(entityType.ClrType);
return Expression.Call(
null,
Expand All @@ -354,5 +382,12 @@ private Expression _GetAccessor(PropertyInfo property, ParameterExpression para)
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
return base.Visit(updatedBody);
}

private static bool IsCompilerGeneratedClosure(Type type) =>
// TypeAttributes.NestedPrivate is a cheap flag check that rules out most types before
// touching the attribute cache.
type.Attributes.HasFlag(TypeAttributes.NestedPrivate) &&
_compilerGeneratedClosureCache.GetValue(type, static t =>
new StrongBox<bool>(Attribute.IsDefined(t, typeof(CompilerGeneratedAttribute), inherit: true))).Value;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SELECT [e].[Id], (
SELECT COUNT(*)
FROM [Entity] AS [e0]
WHERE [e0].[Id] * 2 > 4) AS [SubsetCount]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SELECT [e].[Id], (
SELECT COUNT(*)
FROM [Entity] AS [e0]
WHERE [e0].[Id] * 2 > 4) AS [SubsetCount]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SELECT [e].[Id], (
SELECT COUNT(*)
FROM [Entity] AS [e0]
WHERE [e0].[Id] * 2 > 4) AS [SubsetCount]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DECLARE @lowerBound int = 3;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Id] >= @lowerBound AND [e].[Id] <= 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DECLARE @__lowerBound_0 int = 3;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Id] >= @__lowerBound_0 AND [e].[Id] <= 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DECLARE @__lowerBound_0 int = 3;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Id] >= @__lowerBound_0 AND [e].[Id] <= 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
DECLARE @minCount int = 1;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE ([e].[Id] >= @minCount AND [e].[Id] <= 50) OR EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 10 AND [e0].[Id] <= 100 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
DECLARE @__minCount_0 int = 1;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE ([e].[Id] >= @__minCount_0 AND [e].[Id] <= 50) OR EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 10 AND [e0].[Id] <= 100 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
DECLARE @__minCount_0 int = 1;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE ([e].[Id] >= @__minCount_0 AND [e].[Id] <= 50) OR EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 10 AND [e0].[Id] <= 100 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
DECLARE @lower int = 2;
DECLARE @upper int = 8;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Id] >= @lower AND [e].[Id] <= @upper
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
DECLARE @__lower_0 int = 2;
DECLARE @__upper_1 int = 8;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Id] >= @__lower_0 AND [e].[Id] <= @__upper_1
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
DECLARE @__lower_0 int = 2;
DECLARE @__upper_1 int = 8;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Id] >= @__lower_0 AND [e].[Id] <= @__upper_1
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DECLARE @targetName nvarchar(4000) = N'Alice';

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Name] = @targetName
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DECLARE @__targetName_0 nvarchar(4000) = N'Alice';

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Name] = @__targetName_0
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DECLARE @__targetName_0 nvarchar(4000) = N'Alice';

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Name] = @__targetName_0
Loading
Loading