diff --git a/src/openscad_parser/ast/builder.py b/src/openscad_parser/ast/builder.py index a91e10b..e525fc2 100644 --- a/src/openscad_parser/ast/builder.py +++ b/src/openscad_parser/ast/builder.py @@ -475,14 +475,16 @@ def visit_module_definition(self, node, children): # Flatten nested statement lists. Include all statements (assignments, # function/module declarations, module instantiations) so scope build # can hoist declarations and attach scopes to every node. + # Filter out None values that may result from visit_statement() returning + # None for statement nodes with no children. flattened = [] stack = list(statement) while stack: item = stack.pop(0) if isinstance(item, list): stack = item + stack - continue - flattened.append(item) + elif item is not None: + flattened.append(item) return ModuleDeclaration(name=name, parameters=parameters, children=flattened, position=self._get_node_position(node)) def visit_function_definition(self, node, children): diff --git a/src/openscad_parser/ast/nodes.py b/src/openscad_parser/ast/nodes.py index 3be521d..d7a012a 100644 --- a/src/openscad_parser/ast/nodes.py +++ b/src/openscad_parser/ast/nodes.py @@ -368,6 +368,11 @@ def __str__(self): def build_scope(self, parent_scope: "Scope") -> None: self.scope = parent_scope self.name.build_scope(parent_scope) + # Function literal bodies are closures that resolve variables lazily at + # call time, so the RHS always uses parent_scope (the full scope including + # the variable being assigned). This correctly handles recursive function + # literals and expressions containing function literals + # (e.g. `a = b ? function(x,n) a(...) : function(x,n) a(...)`). self.expr.build_scope(parent_scope) diff --git a/src/openscad_parser/ast/serialization.py b/src/openscad_parser/ast/serialization.py index e61836b..541c3e0 100644 --- a/src/openscad_parser/ast/serialization.py +++ b/src/openscad_parser/ast/serialization.py @@ -210,9 +210,10 @@ def _serialize_node(node: ASTNode, include_position: bool) -> dict[str, Any]: if include_position: result["_position"] = _serialize_position(node.position) - # Get all fields from the dataclass (excluding 'position' which we handle specially) + # Get all fields from the dataclass (excluding 'position' which we handle specially, + # and 'scope' which is runtime metadata not suitable for serialization) for field in dataclasses.fields(node): - if field.name == "position": + if field.name in ("position", "scope"): continue value = getattr(node, field.name) result[field.name] = _serialize_value(value, include_position) diff --git a/tests/test_scope.py b/tests/test_scope.py index 195065d..e247f86 100644 --- a/tests/test_scope.py +++ b/tests/test_scope.py @@ -498,6 +498,21 @@ def test_function_literal_in_expression(self): fl = pc.left # type: ignore assert fl.body.scope.lookup_variable("a") is not None # type: ignore + def test_function_literal_in_ternary_rhs_sees_assigned_variable(self): + """Function literals in a ternary RHS should see the variable being assigned.""" + ast = getASTfromString("a = b ? function(x, n) a(x + n, n - 1) : function(x, n) a(x * n, n - 1);") + assert ast is not None and isinstance(ast, list) + build_scopes(ast) + assignment = ast[0] + ternary = assignment.expr # type: ignore + true_fl = ternary.true_expr # type: ignore + false_fl = ternary.false_expr # type: ignore + assert isinstance(true_fl, FunctionLiteral) + assert isinstance(false_fl, FunctionLiteral) + # Both function bodies should see 'a' for recursion + assert true_fl.body.scope.lookup_variable("a") is not None # type: ignore + assert false_fl.body.scope.lookup_variable("a") is not None # type: ignore + class TestModularCallChildren: """Test that module call children get their own scope."""