Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 39 additions & 26 deletions dialects/cypher/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func NewAnalyzer() *Analyzer {
type variableBinding struct {
variable string // e.g., "u"
labels []string // e.g., ["User"]
optional bool // true if from OPTIONAL MATCH (properties can be null)
}

// queryContext holds context during query analysis.
Expand Down Expand Up @@ -137,15 +138,16 @@ func extractBindingsFromRegularQuery(rq *cyphergrammar.RegularQuery, ctx *queryC
for _, clause := range rq.SingleQuery.Clauses {
// Extract from MATCH clauses
if clause.Reading != nil && clause.Reading.Match != nil {
extractBindingsFromPattern(clause.Reading.Match.Pattern, ctx)
optional := clause.Reading.Match.Optional
extractBindingsFromPattern(clause.Reading.Match.Pattern, ctx, optional)
}
// Extract from CREATE clauses
if clause.Updating != nil && clause.Updating.Create != nil && clause.Updating.Create.Pattern != nil {
extractBindingsFromPattern(clause.Updating.Create.Pattern, ctx)
extractBindingsFromPattern(clause.Updating.Create.Pattern, ctx, false)
}
// Extract from MERGE clauses
if clause.Updating != nil && clause.Updating.Merge != nil && clause.Updating.Merge.Pattern != nil {
extractBindingsFromPatternElement(clause.Updating.Merge.Pattern.Element, ctx)
extractBindingsFromPatternElement(clause.Updating.Merge.Pattern.Element, ctx, false)
}
}

Expand All @@ -154,56 +156,57 @@ func extractBindingsFromRegularQuery(rq *cyphergrammar.RegularQuery, ctx *queryC
if union.Query != nil {
for _, clause := range union.Query.Clauses {
if clause.Reading != nil && clause.Reading.Match != nil {
extractBindingsFromPattern(clause.Reading.Match.Pattern, ctx)
optional := clause.Reading.Match.Optional
extractBindingsFromPattern(clause.Reading.Match.Pattern, ctx, optional)
}
if clause.Updating != nil && clause.Updating.Create != nil && clause.Updating.Create.Pattern != nil {
extractBindingsFromPattern(clause.Updating.Create.Pattern, ctx)
extractBindingsFromPattern(clause.Updating.Create.Pattern, ctx, false)
}
if clause.Updating != nil && clause.Updating.Merge != nil && clause.Updating.Merge.Pattern != nil {
extractBindingsFromPatternElement(clause.Updating.Merge.Pattern.Element, ctx)
extractBindingsFromPatternElement(clause.Updating.Merge.Pattern.Element, ctx, false)
}
}
}
}
}

func extractBindingsFromPattern(pattern *cyphergrammar.Pattern, ctx *queryContext) {
func extractBindingsFromPattern(pattern *cyphergrammar.Pattern, ctx *queryContext, optional bool) {
if pattern == nil {
return
}

for _, part := range pattern.Parts {
if part.Element != nil {
extractBindingsFromPatternElement(part.Element, ctx)
extractBindingsFromPatternElement(part.Element, ctx, optional)
}
}
}

func extractBindingsFromPatternElement(elem *cyphergrammar.PatternElement, ctx *queryContext) {
func extractBindingsFromPatternElement(elem *cyphergrammar.PatternElement, ctx *queryContext, optional bool) {
if elem == nil {
return
}

// Handle parenthesized pattern
if elem.Paren != nil {
extractBindingsFromPatternElement(elem.Paren, ctx)
extractBindingsFromPatternElement(elem.Paren, ctx, optional)
return
}

// Extract from node pattern
if elem.Node != nil {
extractNodeBinding(elem.Node, ctx)
extractNodeBinding(elem.Node, ctx, optional)
}

// Extract from chain
for _, chain := range elem.Chain {
if chain.Node != nil {
extractNodeBinding(chain.Node, ctx)
extractNodeBinding(chain.Node, ctx, optional)
}
}
}

func extractNodeBinding(node *cyphergrammar.NodePattern, ctx *queryContext) {
func extractNodeBinding(node *cyphergrammar.NodePattern, ctx *queryContext, optional bool) {
if node == nil || node.Variable == "" {
return
}
Expand All @@ -216,6 +219,7 @@ func extractNodeBinding(node *cyphergrammar.NodePattern, ctx *queryContext) {
ctx.bindings[node.Variable] = &variableBinding{
variable: node.Variable,
labels: labels,
optional: optional,
}
}

Expand Down Expand Up @@ -788,10 +792,18 @@ func extractProjectionItem(item *cyphergrammar.ProjectionItem, result *scaf.Quer
returnType := inferExpressionType(item.Expr, ctx)

// Get Required from schema for simple property access (e.g., "u.name")
// Default to required=true (non-nullable) unless schema says otherwise
required := true
if field := lookupFieldFromExpression(expression, ctx); field != nil {
required = field.Required
// Conservative default: assume nullable for complex expressions we can't analyze.
// Only use schema's Required for simple "variable.property" patterns where we can
// confidently determine nullability. This avoids generating non-pointer types for
// expressions that might return null (function calls, CASE, aggregates, etc.)
required := false
if field, binding := lookupFieldFromExpression(expression, ctx); field != nil {
// If binding is from OPTIONAL MATCH, always treat as nullable
if binding != nil && binding.optional {
required = false
} else {
required = field.Required
}
}

result.Returns = append(result.Returns, scaf.ReturnInfo{
Expand All @@ -809,16 +821,17 @@ func extractProjectionItem(item *cyphergrammar.ProjectionItem, result *scaf.Quer
}

// lookupFieldFromExpression extracts variable.property from expression and looks up the field.
// Returns nil if expression is not a simple property access or field not found.
func lookupFieldFromExpression(expression string, ctx *queryContext) *analysis.Field {
// Returns (nil, nil) if expression is not a simple property access or field not found.
// Returns the binding as well so caller can check if it's from OPTIONAL MATCH.
func lookupFieldFromExpression(expression string, ctx *queryContext) (*analysis.Field, *variableBinding) {
if ctx.schema == nil {
return nil
return nil, nil
}

// Parse "variable.property" pattern
parts := strings.SplitN(expression, ".", 2)
if len(parts) != 2 {
return nil
return nil, nil
}

varName := parts[0]
Expand All @@ -827,29 +840,29 @@ func lookupFieldFromExpression(expression string, ctx *queryContext) *analysis.F
// Check for additional operations (not a simple property access)
// e.g., "u.name IS NULL", "u.name + 'x'", "u.tags[0]"
if strings.ContainsAny(propName, " []()+<>=!") {
return nil
return nil, nil
}

// Look up the binding to get the model
binding, ok := ctx.bindings[varName]
if !ok || len(binding.labels) == 0 {
return nil
return nil, nil
}

modelName := binding.labels[0]
model, ok := ctx.schema.Models[modelName]
if !ok {
return nil
return nil, binding
}

// Find the field
for _, field := range model.Fields {
if field.Name == propName {
return field
return field, binding
}
}

return nil
return nil, binding
}

// expressionToString converts an Expression AST back to a string representation.
Expand Down
49 changes: 47 additions & 2 deletions dialects/cypher/analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,13 @@ func TestAnalyzer_RequiredField(t *testing.T) {
{Name: "tags", Type: analysis.SliceOf(analysis.TypeString), Required: false},
},
},
"Comment": {
Name: "Comment",
Fields: []*analysis.Field{
{Name: "id", Type: analysis.TypeString, Required: true},
{Name: "text", Type: analysis.TypeString, Required: true},
},
},
},
}

Expand Down Expand Up @@ -784,10 +791,48 @@ func TestAnalyzer_RequiredField(t *testing.T) {
wantRequired: false,
},
{
name: "complex expression defaults to required",
name: "complex expression defaults to nullable",
query: "MATCH (u:User) RETURN u.bio IS NULL",
wantType: "bool",
wantRequired: true, // Not a simple property access
wantRequired: false, // Not a simple property access - bail to nullable
},
// OPTIONAL MATCH tests
{
name: "optional match simple property",
query: "MATCH (c:Comment) OPTIONAL MATCH (u:User)-[:WROTE]->(c) RETURN u.id",
wantType: "string",
wantRequired: false, // OPTIONAL MATCH forces nullable
},
// Complex expression bailout tests - representative cases
{
name: "function call bails to nullable",
query: "MATCH (u:User) RETURN toUpper(u.id)",
wantType: "string",
wantRequired: false, // Function wrapping - can't trust schema
},
{
name: "case expression bails to nullable",
query: "MATCH (u:User) RETURN CASE WHEN u.age > 18 THEN 'adult' ELSE 'minor' END",
wantType: "string",
wantRequired: false, // CASE expression - can return null
},
{
name: "aggregate bails to nullable",
query: "MATCH (u:User) RETURN count(u)",
wantType: "int",
wantRequired: false, // Aggregate - bail to nullable (future: whitelist count as safe)
},
{
name: "list indexing bails to nullable",
query: "MATCH (u:User) RETURN u.tags[0]",
wantType: "string",
wantRequired: false, // Out of bounds returns null
},
{
name: "chained property access bails to nullable",
query: "MATCH (u:User) RETURN u.id.length",
wantType: "", // Unknown type for chained access
wantRequired: false, // Chained access - not simple var.prop
},
}

Expand Down
6 changes: 3 additions & 3 deletions language/go/production_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,13 +233,13 @@ import (
"github.com/example/db"
)

func CountUsers() []int {
func CountUsers() []*int {
return countUsersImpl()
}

var countUsersImpl func() []int = countUsersProd
var countUsersImpl func() []*int = countUsersProd

func countUsersProd() []int {
func countUsersProd() []*int {
return db.Query(ctx, query)
}
`
Expand Down
2 changes: 1 addition & 1 deletion language/go/signature_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ RETURN count(u) AS count
QueryName: "countUsers",
Params: []FuncParam{},
Returns: []FuncReturn{
{Name: "count", Type: "int", IsSlice: false},
{Name: "count", Type: "*int", IsSlice: false}, // Conservative: complex expression bails to nullable
},
},
},
Expand Down