Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@

### New Rules

Rule ID | Category | Severity | Notes
--------|----------|----------|---------------------------------------------------------------------------------------------
EFP0007 | Design | Warning | Projectable member accesses private or protected members of a non-partial containing type
8 changes: 8 additions & 0 deletions src/EntityFrameworkCore.Projectables.Generator/Diagnostics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,13 @@ public static class Diagnostics
DiagnosticSeverity.Error,
isEnabledByDefault: true);

public static readonly DiagnosticDescriptor InaccessibleMemberInNonPartialClass = new DiagnosticDescriptor(
id: "EFP0007",
title: "Projectable member accesses private or protected members",
messageFormat: "Projectable member '{0}' accesses private or protected members of '{1}'. Consider marking '{1}' and all its containing types as partial to allow the generated expression to access these members.",
category: "Design",
DiagnosticSeverity.Warning,
isEnabledByDefault: true);

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,18 @@ public class ProjectableDescriptor
public SyntaxList<TypeParameterConstraintClauseSyntax>? ConstraintClauses { get; set; }

public ExpressionSyntax? ExpressionBody { get; set; }

/// <summary>
/// Whether all containing types of the projectable member are declared as partial.
/// When true, the generated expression class will be nested inside the containing types
/// to allow access to private/protected members.
/// </summary>
public bool IsContainingClassPartial { get; set; }

/// <summary>
/// The chain of containing type declarations (from outermost to innermost / direct containing type).
/// Used when IsContainingClassPartial is true to generate the partial class wrappers.
/// </summary>
public IReadOnlyList<TypeDeclarationSyntax>? ContainingTypeChain { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,76 @@ static IEnumerable<string> GetNestedInClassPathForExtensionMember(ITypeSymbol ex
return [];
}

/// <summary>
/// Returns the chain of containing TypeDeclarationSyntax nodes (outermost first) for the member.
/// Returns an empty list if the member is not inside any type.
/// </summary>
static IReadOnlyList<TypeDeclarationSyntax> GetContainingTypeChain(MemberDeclarationSyntax member)
{
var result = new List<TypeDeclarationSyntax>();
var current = member.Parent;
while (current is TypeDeclarationSyntax typeDecl)
{
result.Insert(0, typeDecl);
current = current.Parent;
}
return result;
}

/// <summary>
/// Checks whether any identifier in the expression resolves to a private or protected member
/// of the given containing type (or one of its base types).
/// </summary>
static bool HasPrivateOrProtectedMemberAccess(
ExpressionSyntax expression,
INamedTypeSymbol containingType,
SemanticModel semanticModel)
{
foreach (var node in expression.DescendantNodesAndSelf())
{
// Skip lambda and anonymous function expressions themselves.
// Their symbols are compiler-synthesized private methods, not actual member accesses.
if (node is AnonymousFunctionExpressionSyntax)
continue;

var symbol = semanticModel.GetSymbolInfo(node).Symbol;
if (symbol is IFieldSymbol or IPropertySymbol or IMethodSymbol or IEventSymbol)
{
// Skip compiler-generated / implicitly declared symbols
if (symbol.IsImplicitlyDeclared)
continue;

// Only warn for accessibility levels that are NOT accessible from a standalone
// generated class in the same assembly:
// - Private: only within the declaring class → NOT accessible
// - Protected (ProtectedAndInternal = private protected): requires derived + same assembly → NOT accessible
// - Protected: requires derived class → NOT accessible
// Excluded: ProtectedOrInternal (protected internal) and Internal are accessible
// from the same assembly, so the generated class CAN access them without partial support.
if (symbol.DeclaredAccessibility is Accessibility.Private or Accessibility.Protected or Accessibility.ProtectedAndInternal)
{
// Check that the member belongs to the containing type (or a base type).
// Don't walk up to System.Object to avoid false positives from
// system-defined protected members (e.g., MemberwiseClone, Finalize).
var ownerType = symbol.ContainingType;
if (ownerType?.SpecialType == SpecialType.System_Object)
continue;

var current = (INamedTypeSymbol?)containingType;
while (current is not null && current.SpecialType != SpecialType.System_Object)
{
if (SymbolEqualityComparer.Default.Equals(current, ownerType))
{
return true;
}
current = current.BaseType;
}
}
}
}
return false;
}

public static ProjectableDescriptor? GetDescriptor(Compilation compilation, MemberDeclarationSyntax member, SourceProductionContext context)
{
var semanticModel = compilation.GetSemanticModel(member.SyntaxTree);
Expand Down Expand Up @@ -169,6 +239,13 @@ x is IPropertySymbol xProperty &&
extensionReceiverType = extensionParameter?.Type;
}

// Check if all containing types are partial (only for class members, not extension members)
var containingTypeChain = !isExtensionMember
? GetContainingTypeChain(member)
: (IReadOnlyList<TypeDeclarationSyntax>)Array.Empty<TypeDeclarationSyntax>();
var isContainingClassPartial = containingTypeChain.Count > 0 &&
containingTypeChain.All(t => t.Modifiers.Any(SyntaxKind.PartialKeyword));

// For extension members, use the extension receiver type for rewriting
var targetTypeForRewriting = isExtensionMember && extensionReceiverType is INamedTypeSymbol receiverNamedType
? receiverNamedType
Expand Down Expand Up @@ -329,6 +406,17 @@ x is IPropertySymbol xProperty &&
{
// Expression-bodied method (e.g., int Foo() => 1;)
bodyExpression = methodDeclarationSyntax.ExpressionBody.Expression;

// Warn if a private/protected member is accessed and the class is not partial
if (!isContainingClassPartial && !isExtensionMember &&
HasPrivateOrProtectedMemberAccess(bodyExpression, memberSymbol.ContainingType, semanticModel))
{
context.ReportDiagnostic(Diagnostic.Create(
Diagnostics.InaccessibleMemberInNonPartialClass,
methodDeclarationSyntax.GetLocation(),
memberSymbol.Name,
memberSymbol.ContainingType.Name));
}
}
else if (methodDeclarationSyntax.Body is not null)
{
Expand Down Expand Up @@ -400,6 +488,17 @@ x is IPropertySymbol xProperty &&
if (propertyDeclarationSyntax.ExpressionBody is not null)
{
bodyExpression = propertyDeclarationSyntax.ExpressionBody.Expression;

// Warn if a private/protected member is accessed and the class is not partial
if (!isContainingClassPartial && !isExtensionMember &&
HasPrivateOrProtectedMemberAccess(bodyExpression, memberSymbol.ContainingType, semanticModel))
{
context.ReportDiagnostic(Diagnostic.Create(
Diagnostics.InaccessibleMemberInNonPartialClass,
propertyDeclarationSyntax.GetLocation(),
memberSymbol.Name,
memberSymbol.ContainingType.Name));
}
}
else if (propertyDeclarationSyntax.AccessorList is not null)
{
Expand All @@ -411,6 +510,17 @@ x is IPropertySymbol xProperty &&
{
// get => expression;
bodyExpression = getter.ExpressionBody.Expression;

// Warn if a private/protected member is accessed and the class is not partial
if (!isContainingClassPartial && !isExtensionMember &&
HasPrivateOrProtectedMemberAccess(bodyExpression, memberSymbol.ContainingType, semanticModel))
{
context.ReportDiagnostic(Diagnostic.Create(
Diagnostics.InaccessibleMemberInNonPartialClass,
propertyDeclarationSyntax.GetLocation(),
memberSymbol.Name,
memberSymbol.ContainingType.Name));
}
}
else if (getter?.Body is not null)
{
Expand Down Expand Up @@ -457,6 +567,13 @@ x is IPropertySymbol xProperty &&
return null;
}

// Set partial class info if all containing types are partial
if (isContainingClassPartial)
{
descriptor.IsContainingClassPartial = true;
descriptor.ContainingTypeChain = containingTypeChain;
}

return descriptor;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,43 +66,34 @@ static void Execute(MemberDeclarationSyntax member, Compilation compilation, Sou
var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(projectable.ClassNamespace, projectable.NestedInClassNames, projectable.MemberName, projectable.ParameterTypeNames);
var generatedFileName = projectable.ClassTypeParameterList is not null ? $"{generatedClassName}-{projectable.ClassTypeParameterList.ChildNodes().Count()}.g.cs" : $"{generatedClassName}.g.cs";

var classSyntax = ClassDeclaration(generatedClassName)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.ClassTypeParameterList)
.WithConstraintClauses(projectable.ClassConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.AddAttributeLists(
AttributeList()
.AddAttributes(_editorBrowsableAttribute)
)
.AddMembers(
MethodDeclaration(
GenericName(
Identifier("global::System.Linq.Expressions.Expression"),
TypeArgumentList(
SingletonSeparatedList(
(TypeSyntax)GenericName(
Identifier("global::System.Func"),
GetLambdaTypeArgumentListSyntax(projectable)
)
)
)
),
"Expression"
)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.TypeParameterList)
.WithConstraintClauses(projectable.ConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.WithBody(
Block(
ReturnStatement(
ParenthesizedLambdaExpression(
projectable.ParametersList ?? ParameterList(),
null,
projectable.ExpressionBody
// Build the Expression method (shared between partial and non-partial generation)
var expressionMethod = MethodDeclaration(
GenericName(
Identifier("global::System.Linq.Expressions.Expression"),
TypeArgumentList(
SingletonSeparatedList(
(TypeSyntax)GenericName(
Identifier("global::System.Func"),
GetLambdaTypeArgumentListSyntax(projectable)
)
)
)
),
"Expression"
)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.TypeParameterList)
.WithConstraintClauses(projectable.ConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.WithBody(
Block(
ReturnStatement(
ParenthesizedLambdaExpression(
projectable.ParametersList ?? ParameterList(),
null,
projectable.ExpressionBody
)
)
)
);

#nullable disable
Expand All @@ -114,29 +105,84 @@ static void Execute(MemberDeclarationSyntax member, Compilation compilation, Sou
compilationUnit = compilationUnit.AddUsings(usingDirective);
}

if (projectable.ClassNamespace is not null)
if (projectable.IsContainingClassPartial && projectable.ContainingTypeChain is not null)
{
compilationUnit = compilationUnit.AddUsings(
UsingDirective(
ParseName(projectable.ClassNamespace)
// Generate the Expression class nested inside the partial containing class(es).
// This allows access to private/protected members of the containing class.
// The nested class does NOT redeclare the outer class's type parameters.
var nestedClassSyntax = ClassDeclaration(generatedClassName)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.AddAttributeLists(
AttributeList()
.AddAttributes(_editorBrowsableAttribute)
)
);
.AddMembers(expressionMethod);

// Wrap in the chain of partial class declarations (innermost first, then we reverse)
MemberDeclarationSyntax wrapped = nestedClassSyntax;
foreach (var typeDecl in projectable.ContainingTypeChain.Reverse())
{
wrapped = typeDecl
.WithAttributeLists(List<AttributeListSyntax>())
.WithBaseList(null)
.WithMembers(SingletonList<MemberDeclarationSyntax>(wrapped));
}

if (projectable.ClassNamespace is not null)
{
compilationUnit = compilationUnit.AddMembers(
NamespaceDeclaration(
ParseName(projectable.ClassNamespace)
).AddMembers(wrapped)
);
}
else
{
compilationUnit = compilationUnit.AddMembers(wrapped);
}
}
else
{
if (projectable.ClassNamespace is not null)
{
// Only add the class namespace using if not already present
var classNamespace = projectable.ClassNamespace;
if (!projectable.UsingDirectives.Any(u => u.Name?.ToString() == classNamespace))
{
compilationUnit = compilationUnit.AddUsings(
UsingDirective(
ParseName(classNamespace)
)
);
}
}

var classSyntax = ClassDeclaration(generatedClassName)
.WithModifiers(TokenList(Token(SyntaxKind.StaticKeyword)))
.WithTypeParameterList(projectable.ClassTypeParameterList)
.WithConstraintClauses(projectable.ClassConstraintClauses ?? List<TypeParameterConstraintClauseSyntax>())
.AddAttributeLists(
AttributeList()
.AddAttributes(_editorBrowsableAttribute)
)
.AddMembers(expressionMethod);

compilationUnit = compilationUnit
.AddMembers(
NamespaceDeclaration(
ParseName("EntityFrameworkCore.Projectables.Generated")
).AddMembers(classSyntax)
);
}

compilationUnit = compilationUnit
.AddMembers(
NamespaceDeclaration(
ParseName("EntityFrameworkCore.Projectables.Generated")
).AddMembers(classSyntax)
)
.WithLeadingTrivia(
TriviaList(
Comment("// <auto-generated/>"),
Trivia(NullableDirectiveTrivia(Token(SyntaxKind.DisableKeyword), true))
)
);


context.AddSource(generatedFileName, SourceText.From(compilationUnit.NormalizeWhitespace().ToFullString(), Encoding.UTF8));

static TypeArgumentListSyntax GetLambdaTypeArgumentListSyntax(ProjectableDescriptor projectable)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,18 @@ public LambdaExpression FindGeneratedExpression(MemberInfo projectableMemberInfo
.ToArray();
}

var generatedContainingTypeName = ProjectionExpressionClassNameGenerator.GenerateFullName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name, parameterTypeNames);
var generatedClassName = ProjectionExpressionClassNameGenerator.GenerateName(declaringType.Namespace, declaringType.GetNestedTypePath().Select(x => x.Name), projectableMemberInfo.Name, parameterTypeNames);
var generatedContainingTypeName = $"{ProjectionExpressionClassNameGenerator.Namespace}.{generatedClassName}";

var expressionFactoryType = declaringType.Assembly.GetType(generatedContainingTypeName);

if (expressionFactoryType is null)
{
// When the containing class is partial, the generated Expression class is a nested
// type inside the declaring type rather than in the Generated namespace.
expressionFactoryType = originalDeclaringType.GetNestedType(generatedClassName, BindingFlags.NonPublic | BindingFlags.Public);
}

if (expressionFactoryType is not null)
{
if (expressionFactoryType.IsGenericTypeDefinition)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
SELECT [i].[Value] * 2
FROM [InternalHelperEntity] AS [i]
Loading