From 2465ca1232f5cd3432e61066873ac4756aded0b1 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 17:10:55 -0400 Subject: [PATCH 01/19] Add forexp op and AST NodeTransformers --- effectful/internals/disassembler.py | 43 +++++++++++++++++++++++++++++ effectful/ops/syntax.py | 22 ++++++++++++++- 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index a06e004b..f547f158 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1075,3 +1075,46 @@ def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression: """ assert inspect.isgenerator(genexpr), "Input must be a generator expression" return ast.fix_missing_locations(ast.Expression(ensure_ast(genexpr))) + + +class NameToCall(ast.NodeTransformer): + varnames: set[str] + + def __init__(self, varnames: set[str]): + self.varnames = varnames + + def visit_Name(self, node: ast.Name) -> ast.Call | ast.Name: + if node.id in self.varnames and isinstance(node.ctx, ast.Load): + return ast.Call(node, args=[], keywords=[]) + else: + return node + + +class GeneratorExpToForexpr(ast.NodeTransformer): + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> ast.Call: + if not all(isinstance(g.target, ast.Name) and not g.ifs for g in node.generators): + raise NotImplementedError("Generator expressions with unpacking and filters not yet implemented yet") + + streams = ast.Dict(keys=[], values=[]) + for gen in node.generators: + key: ast.Name = gen.target + value: ast.Lambda = ast.Lambda( + args=ast.arguments( + posonlyargs=[], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=NameToCall(set(v.id for v in streams.keys)).visit(gen.iter), + ) + streams.keys.append(key) + streams.values.append(value) + + body: ast.expr = NameToCall(set(v.id for v in streams.keys)).visit(node.elt) + + return ast.Call( + func=ast.Name(id='forexpr', ctx=ast.Load()), + args=[body, streams], + 
keywords=[] + ) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 47f06452..d2420294 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1,3 +1,4 @@ +import ast import collections.abc import dataclasses import functools @@ -5,7 +6,7 @@ import random import types import typing -from collections.abc import Callable +from collections.abc import Callable, Generator, Iterable, Mapping from typing import Annotated, Concatenate, Generic, TypeVar import tree @@ -977,6 +978,25 @@ def trace(value: Callable[P, T]) -> Callable[P, T]: return deffn(body, *bound_sig.args, **bound_sig.kwargs) +@defop +def forexpr( + body: Annotated[T, Scoped[A | B]], + streams: Annotated[Mapping[Operation[..., S], Callable[..., Iterable[S]]], Scoped[B]], +) -> Annotated[Iterable[T], Scoped[A]]: + """A higher-order operation that represents a for-expression.""" + raise NotImplementedError + + +@defterm.register(Generator) +def _(genexpr: Generator[T, None, None]) -> Expr[Iterable[T]]: + from effectful.internals.disassembler import GeneratorExpToForexpr, reconstruct + + genexpr_ast = reconstruct(genexpr) + forexpr_ast = GeneratorExpToForexpr().visit(genexpr_ast) + code = compile(ast.fix_missing_locations(forexpr_ast), "", "eval") + return eval(code, genexpr.gi_frame.f_globals, genexpr.gi_frame.f_locals) # type: ignore + + def syntactic_eq(x: Expr[T], other: Expr[T]) -> bool: """Syntactic equality, ignoring the interpretation of the terms. 
From 112d41b4b386cbdd3fc971c2631ae07ef3c24fde Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 17:13:55 -0400 Subject: [PATCH 02/19] format --- effectful/internals/disassembler.py | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index f547f158..84fb5167 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1078,6 +1078,12 @@ def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression: class NameToCall(ast.NodeTransformer): + """ + Transform variable names into calls to those variables. + This transformer replaces occurrences of specified variable names in an AST + with calls to those variables. For example, if the variable name is 'x', + it will replace 'x' with 'x()'. + """ varnames: set[str] def __init__(self, varnames: set[str]): @@ -1091,6 +1097,30 @@ def visit_Name(self, node: ast.Name) -> ast.Call | ast.Name: class GeneratorExpToForexpr(ast.NodeTransformer): + """ + Transform generator expressions into calls to `forexpr`. + This transformer converts generator expressions of the form: + + (expr for var in iter) + into calls to `forexpr`: + forexpr(lambda: expr, {var: lambda: iter}) + It supports multiple nested loops and ensures that variables are correctly + transformed into calls within the expression and iterators. + Note: This implementation currently does not support unpacking in loop variables + or filter conditions (ifs) in the generators. + Raises: + NotImplementedError: If the generator expression contains unpacking + in loop variables or filter conditions. 
+ Example: + >>> import ast + >>> source = "(x * 2 for x in range(10))" + >>> tree = ast.parse(source, mode='eval') + >>> transformer = GeneratorExpToForexpr() + >>> transformed = transformer.visit(tree) + >>> ast.unparse(transformed) + 'forexpr(x() * 2, {x: lambda: range(10)})' + + """ def visit_GeneratorExp(self, node: ast.GeneratorExp) -> ast.Call: if not all(isinstance(g.target, ast.Name) and not g.ifs for g in node.generators): raise NotImplementedError("Generator expressions with unpacking and filters not yet implemented yet") From 15c966125eb86175996933e3db9a66417677a392 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 17:20:19 -0400 Subject: [PATCH 03/19] name --- effectful/ops/syntax.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index d2420294..de8edbd4 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -993,8 +993,13 @@ def _(genexpr: Generator[T, None, None]) -> Expr[Iterable[T]]: genexpr_ast = reconstruct(genexpr) forexpr_ast = GeneratorExpToForexpr().visit(genexpr_ast) - code = compile(ast.fix_missing_locations(forexpr_ast), "", "eval") - return eval(code, genexpr.gi_frame.f_globals, genexpr.gi_frame.f_locals) # type: ignore + forexpr_name = ".".join(genexpr.gi_code.co_name.split(".")[:-1] + [""]) + forexpr_code = compile( + ast.fix_missing_locations(forexpr_ast), + filename=forexpr_name, + mode="eval", + ) + return eval(forexpr_code, genexpr.gi_frame.f_globals, genexpr.gi_frame.f_locals) # type: ignore def syntactic_eq(x: Expr[T], other: Expr[T]) -> bool: From 9b8c2aa8489f9b5e2354b7aed42fa06674bd4bb4 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 17 Jun 2025 18:18:27 -0400 Subject: [PATCH 04/19] simplification --- effectful/internals/disassembler.py | 61 +++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 84fb5167..d9c01cc5 100644 --- 
a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1077,6 +1077,67 @@ def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression: return ast.fix_missing_locations(ast.Expression(ensure_ast(genexpr))) +class IfsToIfAnd(ast.NodeTransformer): + """ + Transform multiple ifs in comprehensions into a single if with 'and'. + This transformer combines multiple filter conditions (ifs) in comprehension + generators into a single condition using logical 'and'. For example, it transforms: + [x for x in range(10) if x % 2 == 0 if x > 5] + into: + [x for x in range(10) if (x % 2 == 0 and x > 5)] + """ + + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> ast.GeneratorExp: + return self._visit_compexp(node) + + def visit_DictComp(self, node: ast.DictComp) -> ast.DictComp: + return self._visit_compexp(node) + + def visit_SetComp(self, node: ast.SetComp) -> ast.SetComp: + return self._visit_compexp(node) + + def visit_ListComp(self, node: ast.ListComp) -> ast.ListComp: + return self._visit_compexp(node) + + @typing.overload + def _visit_compexp(self, node: ast.DictComp) -> ast.DictComp: ... + @typing.overload + def _visit_compexp(self, node: ast.GeneratorExp) -> ast.GeneratorExp: ... + @typing.overload + def _visit_compexp(self, node: ast.SetComp) -> ast.SetComp: ... + @typing.overload + def _visit_compexp(self, node: ast.ListComp) -> ast.ListComp: ... 
+ + def _visit_compexp(self, node: CompExp): + new_generators = [] + for gen in node.generators: + if len(gen.ifs) > 1: + # Combine multiple ifs into a single if with 'and' + combined_if = gen.ifs[0] + for cond in gen.ifs[1:]: + combined_if = ast.BoolOp(op=ast.And(), values=[combined_if, cond]) + new_gen = ast.comprehension( + target=gen.target, + iter=gen.iter, + ifs=[combined_if], + is_async=gen.is_async + ) + new_generators.append(new_gen) + else: + new_generators.append(gen) + if isinstance(node, ast.DictComp): + return ast.DictComp( + key=node.key, + value=node.value, + generators=new_generators + ) + else: + return type(node)( + elt=node.elt, + generators=new_generators + ) + + class NameToCall(ast.NodeTransformer): """ Transform variable names into calls to those variables. From fcc23ba83e1cdc3b1feac40547b0fd80475f4a44 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 18 Jun 2025 19:15:11 -0400 Subject: [PATCH 05/19] Add unit tests and expand functionality of AST transformers --- effectful/internals/disassembler.py | 133 ++++-- tests/test_internals_disassembler.py | 687 +++++++++++++++++++++++++++ 2 files changed, 785 insertions(+), 35 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index cefc209f..8f76d9ba 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1319,22 +1319,17 @@ def _visit_compexp(self, node: CompExp): target=gen.target, iter=gen.iter, ifs=[combined_if], - is_async=gen.is_async + is_async=gen.is_async, ) new_generators.append(new_gen) else: new_generators.append(gen) if isinstance(node, ast.DictComp): return ast.DictComp( - key=node.key, - value=node.value, - generators=new_generators + key=node.key, value=node.value, generators=new_generators ) else: - return type(node)( - elt=node.elt, - generators=new_generators - ) + return type(node)(elt=node.elt, generators=new_generators) class NameToCall(ast.NodeTransformer): @@ -1344,6 +1339,7 @@ class 
NameToCall(ast.NodeTransformer): with calls to those variables. For example, if the variable name is 'x', it will replace 'x' with 'x()'. """ + varnames: set[str] def __init__(self, varnames: set[str]): @@ -1360,18 +1356,18 @@ class GeneratorExpToForexpr(ast.NodeTransformer): """ Transform generator expressions into calls to `forexpr`. This transformer converts generator expressions of the form: - + (expr for var in iter) into calls to `forexpr`: - forexpr(lambda: expr, {var: lambda: iter}) - It supports multiple nested loops and ensures that variables are correctly - transformed into calls within the expression and iterators. - Note: This implementation currently does not support unpacking in loop variables - or filter conditions (ifs) in the generators. - Raises: - NotImplementedError: If the generator expression contains unpacking - in loop variables or filter conditions. - Example: + forexpr(expr, {var: lambda: iter}) + + It supports: + - Multiple nested loops + - Filter conditions (if clauses) - converted to filtered generator expressions + - Tuple unpacking in loop variables + - Variables are correctly transformed into calls within the expression and iterators + + Examples: >>> import ast >>> source = "(x * 2 for x in range(10))" >>> tree = ast.parse(source, mode='eval') @@ -1380,31 +1376,98 @@ class GeneratorExpToForexpr(ast.NodeTransformer): >>> ast.unparse(transformed) 'forexpr(x() * 2, {x: lambda: range(10)})' + >>> source = "(x for x in range(10) if x % 2 == 0)" + >>> tree = ast.parse(source, mode='eval') + >>> transformed = transformer.visit(tree) + >>> ast.unparse(transformed) + 'forexpr(x(), {x: (x for x in range(10) if x % 2 == 0)})' + + >>> source = "((x, y) for x, y in pairs)" + >>> tree = ast.parse(source, mode='eval') + >>> transformed = transformer.visit(tree) + >>> ast.unparse(transformed) + 'forexpr((x(), y()), {(x, y): lambda: pairs})' + """ + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> ast.Call: - if not all(isinstance(g.target, 
ast.Name) and not g.ifs for g in node.generators): - raise NotImplementedError("Generator expressions with unpacking and filters not yet implemented yet") + # Check for unsupported features + for gen in node.generators: + if not isinstance(gen.target, ast.Name) and not isinstance( + gen.target, ast.Tuple + ): + raise NotImplementedError( + f"Unsupported target type: {type(gen.target)}" + ) + + # Get all variable names from all targets (including unpacked tuples) + def get_names_from_target(target): + if isinstance(target, ast.Name): + return [target.id] + elif isinstance(target, ast.Tuple): + names = [] + for elt in target.elts: + names.extend(get_names_from_target(elt)) + return names + else: + raise NotImplementedError( + f"Unsupported target type in unpacking: {type(target)}" + ) streams = ast.Dict(keys=[], values=[]) + all_var_names = set() + for gen in node.generators: - key: ast.Name = gen.target - value: ast.Lambda = ast.Lambda( - args=ast.arguments( - posonlyargs=[], - args=[], - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ), - body=NameToCall(set(v.id for v in streams.keys)).visit(gen.iter), - ) - streams.keys.append(key) + # Collect variable names from previous generators + prev_var_names = set(all_var_names) + + # Add current target variables to the set + target_names = get_names_from_target(gen.target) + all_var_names.update(target_names) + + # Create the value for this generator + if gen.ifs: + # If there are filters, create a generator expression for the filtered iterator + # Note: In the filter conditions, we need to transform previous loop variables + # but NOT the current loop variable + filtered_gen = ast.GeneratorExp( + elt=gen.target if isinstance(gen.target, ast.Name) else gen.target, + generators=[ + ast.comprehension( + target=gen.target, + iter=self.visit(NameToCall(prev_var_names).visit(gen.iter)), + ifs=[ + self.visit(NameToCall(prev_var_names).visit(if_clause)) + for if_clause in gen.ifs + ], + is_async=gen.is_async, + ) + ], + ) + 
value = filtered_gen + else: + # No filters, create a lambda + value = ast.Lambda( + args=ast.arguments( + posonlyargs=[], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=self.visit(NameToCall(prev_var_names).visit(gen.iter)), + ) + + streams.keys.append(gen.target) streams.values.append(value) - body: ast.expr = NameToCall(set(v.id for v in streams.keys)).visit(node.elt) + # Transform the body expression + # First apply NameToCall, then recursively visit for nested generators + body: ast.expr = NameToCall(all_var_names).visit(node.elt) + body = self.visit(body) # Recursively transform nested generator expressions return ast.Call( - func=ast.Name(id='forexpr', ctx=ast.Load()), + func=ast.Name(id="forexpr", ctx=ast.Load()), args=[body, streams], - keywords=[] + keywords=[], ) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index dd5e2c2e..38514b26 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -650,3 +650,690 @@ def test_error_handling(): list(gen) # Consume it with pytest.raises(AssertionError): reconstruct(gen) + + +# ============================================================================ +# AST TRANSFORMER TESTS +# ============================================================================ + + +@pytest.mark.parametrize( + "genexpr_src,expected_src", + [ + # Generator expressions with multiple ifs + ( + "(x for x in range(10) if x % 2 == 0 if x > 5)", + "(x for x in range(10) if x % 2 == 0 and x > 5)", + ), + ( + "(x for x in range(20) if x > 5 if x < 15 if x % 2 == 0)", + "(x for x in range(20) if (x > 5 and x < 15) and x % 2 == 0)", + ), + # List comprehensions with multiple ifs + ( + "[x for x in range(10) if x % 2 == 0 if x > 5]", + "[x for x in range(10) if x % 2 == 0 and x > 5]", + ), + ( + "[x * 2 for x in range(100) if x > 10 if x < 90 if x % 5 == 0]", + "[x * 2 for x in range(100) if (x > 10 and x < 90) and x % 5 == 0]", + ), + # Set 
comprehensions with multiple ifs + ( + "{x for x in range(10) if x % 2 == 0 if x > 5}", + "{x for x in range(10) if x % 2 == 0 and x > 5}", + ), + # Dict comprehensions with multiple ifs + ( + "{x: x**2 for x in range(10) if x % 2 == 0 if x > 5}", + "{x: x**2 for x in range(10) if x % 2 == 0 and x > 5}", + ), + # Single if should remain unchanged + ( + "(x for x in range(10) if x % 2 == 0)", + "(x for x in range(10) if x % 2 == 0)", + ), + # No ifs should remain unchanged + ( + "(x for x in range(10))", + "(x for x in range(10))", + ), + # Multiple generators with multiple ifs each + ( + "(x + y for x in range(5) if x % 2 == 0 if x > 1 for y in range(5) if y % 2 == 1 if y < 4)", + "(x + y for x in range(5) if x % 2 == 0 and x > 1 for y in range(5) if y % 2 == 1 and y < 4)", + ), + # Complex expressions with multiple ifs + ( + "((x, y) for x in range(10) if x > 2 if x < 8 for y in range(10) if y != x if y % 2 == 0)", + "((x, y) for x in range(10) if x > 2 and x < 8 for y in range(10) if y != x and y % 2 == 0)", + ), + ], +) +def test_ifs_to_if_and_transformer(genexpr_src, expected_src): + """Test IfsToIfAnd transformer converts multiple ifs to single if with and.""" + from effectful.internals.disassembler import IfsToIfAnd + + # Parse the source code + source_ast = ast.parse(genexpr_src, mode="eval") + expected_ast = ast.parse(expected_src, mode="eval") + + # Apply the transformer + transformer = IfsToIfAnd() + transformed_ast = transformer.visit(source_ast) + + # Compare the unparsed strings + transformed_src = ast.unparse(transformed_ast) + expected_unparsed = ast.unparse(expected_ast) + + assert transformed_src == expected_unparsed, ( + f"Transformation failed:\n" + f"Input: {genexpr_src}\n" + f"Expected: {expected_unparsed}\n" + f"Got: {transformed_src}" + ) + + +@pytest.mark.parametrize( + "genexpr_src,expected_src", + [ + # Nested comprehensions in body + ( + "([x for x in range(i)] for i in range(5))", + "([x for x in range(i)] for i in range(5))", + ), + ( + 
"([x * 2 for x in range(i) if x % 2 == 0] for i in range(5))", + "([x * 2 for x in range(i) if x % 2 == 0] for i in range(5))", + ), + ( + "({x: x**2 for x in range(i)} for i in range(5))", + "({x: x**2 for x in range(i)} for i in range(5))", + ), + ( + "({x for x in range(i) if x > 2} for i in range(10))", + "({x for x in range(i) if x > 2} for i in range(10))", + ), + # Double nested comprehensions + ( + "([[x for x in range(j)] for j in range(i)] for i in range(3))", + "([[x for x in range(j)] for j in range(i)] for i in range(3))", + ), + ( + "([[x + y for y in range(j)] for j in range(i)] for i in range(3))", + "([[x + y for y in range(j)] for j in range(i)] for i in range(3))", + ), + # Comprehensions with multiple filters that should be combined + ( + "(x for x in range(100) if x > 10 if x < 90 if x % 5 == 0)", + "(x for x in range(100) if (x > 10 and x < 90) and x % 5 == 0)", + ), + ( + "[x**2 for x in range(50) if x % 2 == 0 if x % 3 == 0 if x > 10]", + "[x**2 for x in range(50) if (x % 2 == 0 and x % 3 == 0) and x > 10]", + ), + # Complex nested loops with multiple filters + ( + "((x, y) for x in range(10) if x > 2 if x < 8 for y in range(10) if y != x if y % 2 == 0)", + "((x, y) for x in range(10) if x > 2 and x < 8 for y in range(10) if y != x and y % 2 == 0)", + ), + ( + "[x + y + z for x in range(3) if x > 0 for y in range(3) if y != x for z in range(3) if z != x if z != y]", + "[x + y + z for x in range(3) if x > 0 for y in range(3) if y != x for z in range(3) if (z != x and z != y)]", + ), + # Generator expressions with function calls in body + ( + "(sum([x for x in range(i)]) for i in range(5))", + "(sum([x for x in range(i)]) for i in range(5))", + ), + ( + "(max(x for x in range(i + 1)) for i in range(3))", + "(max(x for x in range(i + 1)) for i in range(3))", + ), + # Nested generator in iterator + ( + "([y for y in range(x)] for x in (x_ + 1 for x_ in range(5)))", + "([y for y in range(x)] for x in (x_ + 1 for x_ in range(5)))", + ), + ( + "(x 
* 2 for x in [y**2 for y in range(5)])", + "(x * 2 for x in [y**2 for y in range(5)])", + ), + # Mixed comprehension types + ( + "({x: [y for y in range(x)] for x in range(5)})", + "({x: [y for y in range(x)] for x in range(5)})", + ), + ( + "[{x: x**2 for x in range(i)} for i in range(5)]", + "[{x: x**2 for x in range(i)} for i in range(5)]", + ), + # Comprehensions with nested filters + ( + "([x for x in range(i) if x < i] for i in range(5) if i > 0)", + "([x for x in range(i) if x < i] for i in range(5) if i > 0)", + ), + ( + "([[x for x in range(i + j) if x < i + j] for j in range(i)] for i in range(5))", + "([[x for x in range(i + j) if x < i + j] for j in range(i)] for i in range(5))", + ), + ], +) +def test_ifs_to_if_and_nested_comprehensions(genexpr_src, expected_src): + """Test IfsToIfAnd transformer with nested comprehensions.""" + from effectful.internals.disassembler import IfsToIfAnd + + # Parse the source code + source_ast = ast.parse(genexpr_src, mode="eval") + expected_ast = ast.parse(expected_src, mode="eval") + + # Apply the transformer + transformer = IfsToIfAnd() + transformed_ast = transformer.visit(source_ast) + + # Compare the unparsed strings + transformed_src = ast.unparse(transformed_ast) + expected_unparsed = ast.unparse(expected_ast) + + assert transformed_src == expected_unparsed, ( + f"Transformation failed:\n" + f"Input: {genexpr_src}\n" + f"Expected: {expected_unparsed}\n" + f"Got: {transformed_src}" + ) + + +@pytest.mark.parametrize( + "source_src,varnames,expected_src", + [ + # Simple name replacement + ("x", {"x"}, "x()"), + ("x + y", {"x"}, "x() + y"), + ("x + y", {"y"}, "x + y()"), + ("x + y", {"x", "y"}, "x() + y()"), + # Names in different contexts (Note: Store/Del contexts work differently) + # These would need to be parsed as statements, not expressions + # Skipping assignment and del tests as they can't be parsed in eval mode + # Complex expressions + ("x * 2 + y / 3", {"x", "y"}, "x() * 2 + y() / 3"), + ("func(x, y, z)", 
{"x", "z"}, "func(x(), y, z())"), + # Nested expressions + ("x + (y * z)", {"x", "y", "z"}, "x() + (y() * z())"), + ("[x, y, z]", {"x", "y"}, "[x(), y(), z]"), + ("{'a': x, 'b': y}", {"x", "y"}, "{'a': x(), 'b': y()}"), + # Attribute access + ("x.attr", {"x"}, "x().attr"), + ("x.method()", {"x"}, "x().method()"), + ("obj.x", {"x"}, "obj.x"), # x is an attribute, not a variable + # Subscript operations + ("x[0]", {"x"}, "x()[0]"), + ("arr[x]", {"x"}, "arr[x()]"), + ("x[y]", {"x", "y"}, "x()[y()]"), + # Function calls + ("f(x)", {"f"}, "f()(x)"), + ("f(x)", {"x"}, "f(x())"), + ("f(x, y)", {"f", "x", "y"}, "f()(x(), y())"), + # Lambda expressions + ("lambda a: x + a", {"x"}, "lambda a: x() + a"), + ( + "lambda x: x + y", + {"x", "y"}, + "lambda x: x() + y()", + ), # Transformer doesn't check scope + # Comprehensions + ("[x for i in range(3)]", {"x"}, "[x() for i in range(3)]"), + ("[i for i in x]", {"x"}, "[i for i in x()]"), + ( + "[x for x in range(3)]", + {"x"}, + "[x() for x in range(3)]", + ), # Transformer doesn't check scope + ("(x + i for i in range(3))", {"x"}, "(x() + i for i in range(3))"), + # Multiple occurrences + ("x + x * x", {"x"}, "x() + x() * x()"), + # Names not in the set should not be transformed + ("x + y + z", {"x"}, "x() + y + z"), + ("a + b + c", {"x", "y", "z"}, "a + b + c"), + ], +) +def test_name_to_call_transformer(source_src, varnames, expected_src): + """Test NameToCall transformer converts specified names to function calls.""" + from effectful.internals.disassembler import NameToCall + + # Parse the source code + source_ast = ast.parse(source_src, mode="eval") + expected_ast = ast.parse(expected_src, mode="eval") + + # Apply the transformer + transformer = NameToCall(varnames) + transformed_ast = transformer.visit(source_ast) + + # Compare the unparsed strings + transformed_src = ast.unparse(transformed_ast) + expected_unparsed = ast.unparse(expected_ast) + + assert transformed_src == expected_unparsed, ( + f"Transformation failed:\n" 
+ f"Input: {source_src}\n" + f"Varnames: {varnames}\n" + f"Expected: {expected_unparsed}\n" + f"Got: {transformed_src}" + ) + + +@pytest.mark.parametrize( + "source_src,varnames,expected_src", + [ + # Nested comprehensions in body + ("[x for x in range(i)]", {"i"}, "[x for x in range(i())]"), + ( + "[[x for x in range(j)] for j in range(i)]", + {"i"}, + "[[x for x in range(j)] for j in range(i())]", + ), + ( + "[[x for x in range(j)] for j in range(i)]", + {"i", "j"}, + "[[x for x in range(j())] for j in range(i())]", + ), + # Complex nested comprehensions + ("sum(x for x in range(i))", {"i"}, "sum((x for x in range(i())))"), + ("sum(x for x in range(i))", {"x"}, "sum((x() for x in range(i)))"), + ("sum(x for x in range(i))", {"i", "x"}, "sum((x() for x in range(i())))"), + # Names in nested comprehensions + ("[y for y in range(x)]", {"x"}, "[y for y in range(x())]"), + ("[y for y in range(x)]", {"y"}, "[y() for y in range(x)]"), + ( + "[[y + z for y in range(x)] for z in range(x)]", + {"x"}, + "[[y + z for y in range(x())] for z in range(x())]", + ), + ( + "[[y + z for y in range(x)] for z in range(x)]", + {"x", "y", "z"}, + "[[y() + z() for y in range(x())] for z in range(x())]", + ), + # Mixed comprehension types with names + ( + "{x: [y for y in range(x)] for x in range(n)}", + {"n"}, + "{x: [y for y in range(x)] for x in range(n())}", + ), + ( + "{i: [j for j in range(i)] for i in range(n)}", + {"i", "j", "n"}, + "{i(): [j() for j in range(i())] for i in range(n())}", + ), + # Names in different parts of comprehensions + ("[f(x) for x in data]", {"f"}, "[f()(x) for x in data]"), + ("[f(x) for x in data]", {"data"}, "[f(x) for x in data()]"), + ("[f(x) for x in data]", {"f", "data"}, "[f()(x) for x in data()]"), + ( + "[f(x) for x in data if pred(x)]", + {"f", "pred", "data"}, + "[f()(x) for x in data() if pred()(x)]", + ), + # Nested function calls + ("f(g(x))", {"f"}, "f()(g(x))"), + ("f(g(x))", {"g"}, "f(g()(x))"), + ("f(g(x))", {"x"}, "f(g(x()))"), + 
("f(g(h(x)))", {"f", "g", "h", "x"}, "f()(g()(h()(x())))"), + # Complex expressions with comprehensions + ( + "sum([x * y for x in range(a) for y in range(b)])", + {"a", "b"}, + "sum([x * y for x in range(a()) for y in range(b())])", + ), + ( + "max(x + y for x in items1 for y in items2)", + {"items1", "items2"}, + "max((x + y for x in items1() for y in items2()))", + ), + # Boolean operations with names + ("x and y", {"x"}, "x() and y"), + ("x and y", {"x", "y"}, "x() and y()"), + ("x or y or z", {"x", "z"}, "x() or y or z()"), + ("not x", {"x"}, "not x()"), + # Ternary expressions + ("x if cond else y", {"cond"}, "x if cond() else y"), + ("x if cond else y", {"x", "y", "cond"}, "x() if cond() else y()"), + # Names in slice operations + ("arr[start:end]", {"start", "end"}, "arr[start():end()]"), + ("arr[i:j:k]", {"i", "j", "k"}, "arr[i():j():k()]"), + ("matrix[i][j]", {"i", "j"}, "matrix[i()][j()]"), + # Dict/set comprehensions + ("{x: y for x, y in pairs}", {"pairs"}, "{x: y for (x, y) in pairs()}"), + ("{f(x) for x in items}", {"f", "items"}, "{f()(x) for x in items()}"), + # Nested comprehensions with filters + ( + "[x for x in [y for y in range(n) if y > m] if x < k]", + {"n", "m", "k"}, + "[x for x in [y for y in range(n()) if y > m()] if x < k()]", + ), + ], +) +def test_name_to_call_nested_comprehensions(source_src, varnames, expected_src): + """Test NameToCall transformer with nested comprehensions and complex expressions.""" + from effectful.internals.disassembler import NameToCall + + # Parse the source code + source_ast = ast.parse(source_src, mode="eval") + expected_ast = ast.parse(expected_src, mode="eval") + + # Apply the transformer + transformer = NameToCall(varnames) + transformed_ast = transformer.visit(source_ast) + + # Compare the unparsed strings + transformed_src = ast.unparse(transformed_ast) + expected_unparsed = ast.unparse(expected_ast) + + assert transformed_src == expected_unparsed, ( + f"Transformation failed:\n" + f"Input: 
{source_src}\n" + f"Varnames: {varnames}\n" + f"Expected: {expected_unparsed}\n" + f"Got: {transformed_src}" + ) + + +@pytest.mark.parametrize( + "genexpr_src,expected_src", + [ + # Simple generator expressions + ( + "(x for x in range(10))", + "forexpr(x(), {x: lambda: range(10)})", + ), + ( + "(x * 2 for x in range(10))", + "forexpr(x() * 2, {x: lambda: range(10)})", + ), + ( + "(x + 1 for x in items)", + "forexpr(x() + 1, {x: lambda: items})", + ), + # Complex expressions + ( + "(x ** 2 + 2 * x + 1 for x in range(5))", + "forexpr(x() ** 2 + 2 * x() + 1, {x: lambda: range(5)})", + ), + ( + "(f(x) for x in data)", + "forexpr(f(x()), {x: lambda: data})", + ), + # Multiple nested loops + ( + "(x + y for x in range(3) for y in range(4))", + "forexpr(x() + y(), {x: lambda: range(3), y: lambda: range(4)})", + ), + ( + "(x * y for x in items1 for y in items2)", + "forexpr(x() * y(), {x: lambda: items1, y: lambda: items2})", + ), + # Nested loops with dependencies + ( + "(x + y for x in range(3) for y in range(x))", + "forexpr(x() + y(), {x: lambda: range(3), y: lambda: range(x())})", + ), + ( + "((x, y) for x in range(3) for y in range(x, 5))", + "forexpr((x(), y()), {x: lambda: range(3), y: lambda: range(x(), 5)})", + ), + # Triple nested loops + ( + "(x + y + z for x in range(2) for y in range(2) for z in range(2))", + "forexpr(x() + y() + z(), {x: lambda: range(2), y: lambda: range(2), z: lambda: range(2)})", + ), + # Complex iterators + ( + "(x for x in [1, 2, 3])", + "forexpr(x(), {x: lambda: [1, 2, 3]})", + ), + ( + "(x for x in list(range(5)))", + "forexpr(x(), {x: lambda: list(range(5))})", + ), + # Expressions with function calls on iterators + ( + "(x for x in sorted(items))", + "forexpr(x(), {x: lambda: sorted(items)})", + ), + # Generator expressions with filters + ( + "(x for x in range(10) if x % 2 == 0)", + "forexpr(x(), {x: (x for x in range(10) if x % 2 == 0)})", + ), + ( + "(x + y for x in range(3) if x > 0 for y in range(3))", + "forexpr(x() + y(), 
{x: (x for x in range(3) if x > 0), y: lambda: range(3)})", + ), + ( + "(x * 2 for x in items if x > 5)", + "forexpr(x() * 2, {x: (x for x in items if x > 5)})", + ), + ( + "(x + y for x in range(5) for y in range(5) if x < y)", + "forexpr(x() + y(), {x: lambda: range(5), y: (y for y in range(5) if x() < y)})", + ), + ( + "(x for x in range(20) if x % 2 == 0 if x % 3 == 0)", + "forexpr(x(), {x: (x for x in range(20) if x % 2 == 0 if x % 3 == 0)})", + ), + # Generator expressions with unpacking + ( + "((x, y) for x, y in pairs)", + "forexpr((x(), y()), {(x, y): lambda: pairs})", + ), + ( + "(a + b for a, b in zip(list1, list2))", + "forexpr(a() + b(), {(a, b): lambda: zip(list1, list2)})", + ), + ( + "(x + y + z for x, (y, z) in nested_pairs)", + "forexpr(x() + y() + z(), {(x, (y, z)): lambda: nested_pairs})", + ), + ], +) +def test_generator_exp_to_forexpr_transformer(genexpr_src, expected_src): + """Test GeneratorExpToForexpr transformer converts generator expressions to forexpr calls.""" + from effectful.internals.disassembler import GeneratorExpToForexpr + + # Parse the source code + source_ast = ast.parse(genexpr_src, mode="eval") + + # Apply the transformer + transformer = GeneratorExpToForexpr() + + # For test cases that should raise NotImplementedError + if expected_src == "NOT_APPLICABLE": + # The xfail marker will handle the exception + transformed_ast = transformer.visit(source_ast) + return + + transformed_ast = transformer.visit(source_ast) + expected_ast = ast.parse(expected_src, mode="eval") + + # Compare the unparsed strings + transformed_src = ast.unparse(transformed_ast) + expected_unparsed = ast.unparse(expected_ast) + + assert transformed_src == expected_unparsed, ( + f"Transformation failed:\n" + f"Input: {genexpr_src}\n" + f"Expected: {expected_unparsed}\n" + f"Got: {transformed_src}" + ) + + +@pytest.mark.parametrize( + "genexpr_src,expected_src", + [ + # Generator expressions yielding comprehensions + ( + "([x for x in range(i)] for i in 
range(5))", + "forexpr([x for x in range(i())], {i: lambda: range(5)})", + ), + ( + "({x: x**2 for x in range(i)} for i in range(5))", + "forexpr({x: x**2 for x in range(i())}, {i: lambda: range(5)})", + ), + ( + "({x for x in range(i) if x > 2} for i in range(10))", + "forexpr({x for x in range(i()) if x > 2}, {i: lambda: range(10)})", + ), + # Double nested comprehensions + ( + "([[x for x in range(j)] for j in range(i)] for i in range(3))", + "forexpr([[x for x in range(j)] for j in range(i())], {i: lambda: range(3)})", + ), + ( + "([[x + y for y in range(j)] for j in range(i)] for i in range(3))", + "forexpr([[x + y for y in range(j)] for j in range(i())], {i: lambda: range(3)})", + ), + # Generator with comprehension in iterator + ( + "(x * 2 for x in [y**2 for y in range(5)])", + "forexpr(x() * 2, {x: lambda: [y**2 for y in range(5)]})", + ), + ( + "(x for x in {y: y**2 for y in range(3)}.values())", + "forexpr(x(), {x: lambda: {y: y**2 for y in range(3)}.values()})", + ), + # Complex expressions with nested calls + ( + "(sum([x for x in range(i)]) for i in range(5))", + "forexpr(sum([x for x in range(i())]), {i: lambda: range(5)})", + ), + ( + "(max(x for x in range(i + 1)) for i in range(3))", + "forexpr(max(forexpr(x(), {x: lambda: range(i() + 1)})), {i: lambda: range(3)})", + ), + ( + "(list(enumerate(x + 1 for x in range(i + 1))) for i in range(3))", + "forexpr(list(enumerate(forexpr(x() + 1, {x: lambda: range(i() + 1)}))), {i: lambda: range(3)})", + ), + # Nested generators in iterator + ( + "([y for y in range(x)] for x in (z + 1 for z in range(5)))", + "forexpr([y for y in range(x())], {x: lambda: forexpr(z() + 1, {z: lambda: range(5)})})", + ), + # Complex filters with nested comprehensions + ( + "(x for x in range(10) if x in [y**2 for y in range(5)])", + "forexpr(x(), {x: (x for x in range(10) if x in [y**2 for y in range(5)])})", + ), + ( + "([x, y] for x in range(3) if x > 0 for y in range(3) if y in [z for z in range(x)])", + "forexpr([x(), 
y()], {x: (x for x in range(3) if x > 0), y: (y for y in range(3) if y in [z for z in range(x())])})", + ), + # Multiple filters with dependencies + ( + "(x + y for x in range(5) if x > 1 for y in range(x) if y < x - 1)", + "forexpr(x() + y(), {x: (x for x in range(5) if x > 1), y: (y for y in range(x()) if y < x() - 1)})", + ), + ( + "((x, y, z) for x in range(3) for y in range(x, 5) if y > x for z in range(y) if z < y)", + "forexpr((x(), y(), z()), {x: lambda: range(3), y: (y for y in range(x(), 5) if y > x()), z: (z for z in range(y()) if z < y())})", + ), + # Unpacking with nested structures + ( + "(a + b + c for (a, b), c in [((1, 2), 3), ((4, 5), 6)])", + "forexpr(a() + b() + c(), {((a, b), c): lambda: [((1, 2), 3), ((4, 5), 6)]})", + ), + ( + "(x + sum(lst) for x, lst in [(1, [2, 3]), (4, [5, 6])])", + "forexpr(x() + sum(lst()), {(x, lst): lambda: [(1, [2, 3]), (4, [5, 6])]})", + ), + # Complex iterators + ( + "(x for x in sorted([y**2 for y in range(5)]))", + "forexpr(x(), {x: lambda: sorted([y**2 for y in range(5)])})", + ), + ( + "(item for sublist in [[1, 2], [3, 4], [5, 6]] for item in sublist)", + "forexpr(item(), {sublist: lambda: [[1, 2], [3, 4], [5, 6]], item: lambda: sublist()})", + ), + # Expressions with method calls + ( + "(s.upper() for s in ['hello', 'world'] if s.startswith('h'))", + "forexpr(s().upper(), {s: (s for s in ['hello', 'world'] if s.startswith('h'))})", + ), + ( + "(obj.value for obj in objects if hasattr(obj, 'value'))", + "forexpr(obj().value, {obj: (obj for obj in objects if hasattr(obj, 'value'))})", + ), + # CRITICAL: Generator expressions yielding generator expressions + ( + "((x for x in range(i)) for i in range(5))", + "forexpr(forexpr(x(), {x: lambda: range(i())}), {i: lambda: range(5)})", + ), + ( + "((x * 2 for x in range(i)) for i in range(3))", + "forexpr(forexpr(x() * 2, {x: lambda: range(i())}), {i: lambda: range(3)})", + ), + ( + "((x + y for x in range(3) for y in range(x)) for i in range(2))", + 
"forexpr(forexpr(x() + y(), {x: lambda: range(3), y: lambda: range(x())}), {i: lambda: range(2)})", + ), + # Generator yielding filtered generator + ( + "((x for x in range(10) if x % 2 == 0) for i in range(3))", + "forexpr(forexpr(x(), {x: (x for x in range(10) if x % 2 == 0)}), {i: lambda: range(3)})", + ), + ( + "((x for x in range(i) if x > 0) for i in range(5))", + "forexpr(forexpr(x(), {x: (x for x in range(i()) if x > 0)}), {i: lambda: range(5)})", + ), + # Nested generators with multiple levels + ( + "((y for y in (x for x in range(i))) for i in range(3))", + "forexpr(forexpr(y(), {y: lambda: forexpr(x(), {x: lambda: range(i())})}), {i: lambda: range(3)})", + ), + ( + "(((x + y for x in range(2)) for y in range(3)) for z in range(4))", + "forexpr(forexpr(forexpr(x() + y(), {x: lambda: range(2)}), {y: lambda: range(3)}), {z: lambda: range(4)})", + ), + # Generator with unpacking yielding generator + ( + "((x + b for x in range(a)) for a, b in [(2, 3), (4, 5)])", + "forexpr(forexpr(x() + b(), {x: lambda: range(a())}), {(a, b): lambda: [(2, 3), (4, 5)]})", + ), + # Complex case: generator yielding generator with filters and dependencies + ( + "((x + y for x in range(i) if x > 0 for y in range(x)) for i in range(5) if i > 2)", + "forexpr(forexpr(x() + y(), {x: (x for x in range(i()) if x > 0), y: lambda: range(x())}), {i: (i for i in range(5) if i > 2)})", + ), + # Generator expression yielding sum of generator expression + ( + "(sum(x for x in range(i)) for i in range(5))", + "forexpr(sum(forexpr(x(), {x: lambda: range(i())})), {i: lambda: range(5)})", + ), + ( + "(max(x * 2 for x in range(i) if x > 0) for i in range(10))", + "forexpr(max(forexpr(x() * 2, {x: (x for x in range(i()) if x > 0)})), {i: lambda: range(10)})", + ), + ], +) +def test_generator_exp_to_forexpr_nested_comprehensions(genexpr_src, expected_src): + """Test GeneratorExpToForexpr transformer with nested comprehensions and complex expressions.""" + from effectful.internals.disassembler import 
GeneratorExpToForexpr + + # Parse the source code + source_ast = ast.parse(genexpr_src, mode="eval") + + # Apply the transformer + transformer = GeneratorExpToForexpr() + transformed_ast = transformer.visit(source_ast) + expected_ast = ast.parse(expected_src, mode="eval") + + # Compare the unparsed strings + transformed_src = ast.unparse(transformed_ast) + expected_unparsed = ast.unparse(expected_ast) + + assert transformed_src == expected_unparsed, ( + f"Transformation failed:\n" + f"Input: {genexpr_src}\n" + f"Expected: {expected_unparsed}\n" + f"Got: {transformed_src}" + ) From ea5c2fd41138617c52815cc17c8cebbe07386d5e Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 18 Jun 2025 19:19:25 -0400 Subject: [PATCH 06/19] format --- effectful/ops/syntax.py | 4 +++- tests/test_internals_disassembler.py | 22 +++++++--------------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index de8edbd4..57643d34 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -981,7 +981,9 @@ def trace(value: Callable[P, T]) -> Callable[P, T]: @defop def forexpr( body: Annotated[T, Scoped[A | B]], - streams: Annotated[Mapping[Operation[..., S], Callable[..., Iterable[S]]], Scoped[B]], + streams: Annotated[ + Mapping[Operation[..., S], Callable[..., Iterable[S]]], Scoped[B] + ], ) -> Annotated[Iterable[T], Scoped[A]]: """A higher-order operation that represents a for-expression.""" raise NotImplementedError diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 38514b26..e02de2ae 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -4,7 +4,13 @@ import pytest -from effectful.internals.disassembler import reconstruct +from effectful.internals.disassembler import ( + GeneratorExpToForexpr, + IfsToIfAnd, + NameToCall, + ensure_ast, + reconstruct, +) def compile_and_eval( @@ -628,8 +634,6 @@ def test_complex_scenarios(genexpr, 
globals_dict): ) def test_ensure_ast(value, expected_str): """Test that ensure_ast correctly converts various values to AST nodes.""" - from effectful.internals.disassembler import ensure_ast - result = ensure_ast(value) # Compare the unparsed strings @@ -712,8 +716,6 @@ def test_error_handling(): ) def test_ifs_to_if_and_transformer(genexpr_src, expected_src): """Test IfsToIfAnd transformer converts multiple ifs to single if with and.""" - from effectful.internals.disassembler import IfsToIfAnd - # Parse the source code source_ast = ast.parse(genexpr_src, mode="eval") expected_ast = ast.parse(expected_src, mode="eval") @@ -821,8 +823,6 @@ def test_ifs_to_if_and_transformer(genexpr_src, expected_src): ) def test_ifs_to_if_and_nested_comprehensions(genexpr_src, expected_src): """Test IfsToIfAnd transformer with nested comprehensions.""" - from effectful.internals.disassembler import IfsToIfAnd - # Parse the source code source_ast = ast.parse(genexpr_src, mode="eval") expected_ast = ast.parse(expected_src, mode="eval") @@ -898,8 +898,6 @@ def test_ifs_to_if_and_nested_comprehensions(genexpr_src, expected_src): ) def test_name_to_call_transformer(source_src, varnames, expected_src): """Test NameToCall transformer converts specified names to function calls.""" - from effectful.internals.disassembler import NameToCall - # Parse the source code source_ast = ast.parse(source_src, mode="eval") expected_ast = ast.parse(expected_src, mode="eval") @@ -1014,8 +1012,6 @@ def test_name_to_call_transformer(source_src, varnames, expected_src): ) def test_name_to_call_nested_comprehensions(source_src, varnames, expected_src): """Test NameToCall transformer with nested comprehensions and complex expressions.""" - from effectful.internals.disassembler import NameToCall - # Parse the source code source_ast = ast.parse(source_src, mode="eval") expected_ast = ast.parse(expected_src, mode="eval") @@ -1137,8 +1133,6 @@ def test_name_to_call_nested_comprehensions(source_src, varnames, 
expected_src): ) def test_generator_exp_to_forexpr_transformer(genexpr_src, expected_src): """Test GeneratorExpToForexpr transformer converts generator expressions to forexpr calls.""" - from effectful.internals.disassembler import GeneratorExpToForexpr - # Parse the source code source_ast = ast.parse(genexpr_src, mode="eval") @@ -1317,8 +1311,6 @@ def test_generator_exp_to_forexpr_transformer(genexpr_src, expected_src): ) def test_generator_exp_to_forexpr_nested_comprehensions(genexpr_src, expected_src): """Test GeneratorExpToForexpr transformer with nested comprehensions and complex expressions.""" - from effectful.internals.disassembler import GeneratorExpToForexpr - # Parse the source code source_ast = ast.parse(genexpr_src, mode="eval") From 42d711e8fff0707b1731e5725ab43002afaae2b0 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 18 Jun 2025 19:25:58 -0400 Subject: [PATCH 07/19] Remove unused IfsToIfAnd transformer --- effectful/internals/disassembler.py | 56 -------- tests/test_internals_disassembler.py | 183 --------------------------- 2 files changed, 239 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 8f76d9ba..94567e79 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1276,62 +1276,6 @@ def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression: return ast.fix_missing_locations(ast.Expression(ensure_ast(genexpr))) -class IfsToIfAnd(ast.NodeTransformer): - """ - Transform multiple ifs in comprehensions into a single if with 'and'. - This transformer combines multiple filter conditions (ifs) in comprehension - generators into a single condition using logical 'and'. 
For example, it transforms: - [x for x in range(10) if x % 2 == 0 if x > 5] - into: - [x for x in range(10) if (x % 2 == 0 and x > 5)] - """ - - def visit_GeneratorExp(self, node: ast.GeneratorExp) -> ast.GeneratorExp: - return self._visit_compexp(node) - - def visit_DictComp(self, node: ast.DictComp) -> ast.DictComp: - return self._visit_compexp(node) - - def visit_SetComp(self, node: ast.SetComp) -> ast.SetComp: - return self._visit_compexp(node) - - def visit_ListComp(self, node: ast.ListComp) -> ast.ListComp: - return self._visit_compexp(node) - - @typing.overload - def _visit_compexp(self, node: ast.DictComp) -> ast.DictComp: ... - @typing.overload - def _visit_compexp(self, node: ast.GeneratorExp) -> ast.GeneratorExp: ... - @typing.overload - def _visit_compexp(self, node: ast.SetComp) -> ast.SetComp: ... - @typing.overload - def _visit_compexp(self, node: ast.ListComp) -> ast.ListComp: ... - - def _visit_compexp(self, node: CompExp): - new_generators = [] - for gen in node.generators: - if len(gen.ifs) > 1: - # Combine multiple ifs into a single if with 'and' - combined_if = gen.ifs[0] - for cond in gen.ifs[1:]: - combined_if = ast.BoolOp(op=ast.And(), values=[combined_if, cond]) - new_gen = ast.comprehension( - target=gen.target, - iter=gen.iter, - ifs=[combined_if], - is_async=gen.is_async, - ) - new_generators.append(new_gen) - else: - new_generators.append(gen) - if isinstance(node, ast.DictComp): - return ast.DictComp( - key=node.key, value=node.value, generators=new_generators - ) - else: - return type(node)(elt=node.elt, generators=new_generators) - - class NameToCall(ast.NodeTransformer): """ Transform variable names into calls to those variables. 
diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index e02de2ae..00553db9 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -6,7 +6,6 @@ from effectful.internals.disassembler import ( GeneratorExpToForexpr, - IfsToIfAnd, NameToCall, ensure_ast, reconstruct, @@ -661,188 +660,6 @@ def test_error_handling(): # ============================================================================ -@pytest.mark.parametrize( - "genexpr_src,expected_src", - [ - # Generator expressions with multiple ifs - ( - "(x for x in range(10) if x % 2 == 0 if x > 5)", - "(x for x in range(10) if x % 2 == 0 and x > 5)", - ), - ( - "(x for x in range(20) if x > 5 if x < 15 if x % 2 == 0)", - "(x for x in range(20) if (x > 5 and x < 15) and x % 2 == 0)", - ), - # List comprehensions with multiple ifs - ( - "[x for x in range(10) if x % 2 == 0 if x > 5]", - "[x for x in range(10) if x % 2 == 0 and x > 5]", - ), - ( - "[x * 2 for x in range(100) if x > 10 if x < 90 if x % 5 == 0]", - "[x * 2 for x in range(100) if (x > 10 and x < 90) and x % 5 == 0]", - ), - # Set comprehensions with multiple ifs - ( - "{x for x in range(10) if x % 2 == 0 if x > 5}", - "{x for x in range(10) if x % 2 == 0 and x > 5}", - ), - # Dict comprehensions with multiple ifs - ( - "{x: x**2 for x in range(10) if x % 2 == 0 if x > 5}", - "{x: x**2 for x in range(10) if x % 2 == 0 and x > 5}", - ), - # Single if should remain unchanged - ( - "(x for x in range(10) if x % 2 == 0)", - "(x for x in range(10) if x % 2 == 0)", - ), - # No ifs should remain unchanged - ( - "(x for x in range(10))", - "(x for x in range(10))", - ), - # Multiple generators with multiple ifs each - ( - "(x + y for x in range(5) if x % 2 == 0 if x > 1 for y in range(5) if y % 2 == 1 if y < 4)", - "(x + y for x in range(5) if x % 2 == 0 and x > 1 for y in range(5) if y % 2 == 1 and y < 4)", - ), - # Complex expressions with multiple ifs - ( - "((x, y) for x in range(10) 
if x > 2 if x < 8 for y in range(10) if y != x if y % 2 == 0)", - "((x, y) for x in range(10) if x > 2 and x < 8 for y in range(10) if y != x and y % 2 == 0)", - ), - ], -) -def test_ifs_to_if_and_transformer(genexpr_src, expected_src): - """Test IfsToIfAnd transformer converts multiple ifs to single if with and.""" - # Parse the source code - source_ast = ast.parse(genexpr_src, mode="eval") - expected_ast = ast.parse(expected_src, mode="eval") - - # Apply the transformer - transformer = IfsToIfAnd() - transformed_ast = transformer.visit(source_ast) - - # Compare the unparsed strings - transformed_src = ast.unparse(transformed_ast) - expected_unparsed = ast.unparse(expected_ast) - - assert transformed_src == expected_unparsed, ( - f"Transformation failed:\n" - f"Input: {genexpr_src}\n" - f"Expected: {expected_unparsed}\n" - f"Got: {transformed_src}" - ) - - -@pytest.mark.parametrize( - "genexpr_src,expected_src", - [ - # Nested comprehensions in body - ( - "([x for x in range(i)] for i in range(5))", - "([x for x in range(i)] for i in range(5))", - ), - ( - "([x * 2 for x in range(i) if x % 2 == 0] for i in range(5))", - "([x * 2 for x in range(i) if x % 2 == 0] for i in range(5))", - ), - ( - "({x: x**2 for x in range(i)} for i in range(5))", - "({x: x**2 for x in range(i)} for i in range(5))", - ), - ( - "({x for x in range(i) if x > 2} for i in range(10))", - "({x for x in range(i) if x > 2} for i in range(10))", - ), - # Double nested comprehensions - ( - "([[x for x in range(j)] for j in range(i)] for i in range(3))", - "([[x for x in range(j)] for j in range(i)] for i in range(3))", - ), - ( - "([[x + y for y in range(j)] for j in range(i)] for i in range(3))", - "([[x + y for y in range(j)] for j in range(i)] for i in range(3))", - ), - # Comprehensions with multiple filters that should be combined - ( - "(x for x in range(100) if x > 10 if x < 90 if x % 5 == 0)", - "(x for x in range(100) if (x > 10 and x < 90) and x % 5 == 0)", - ), - ( - "[x**2 for x in 
range(50) if x % 2 == 0 if x % 3 == 0 if x > 10]", - "[x**2 for x in range(50) if (x % 2 == 0 and x % 3 == 0) and x > 10]", - ), - # Complex nested loops with multiple filters - ( - "((x, y) for x in range(10) if x > 2 if x < 8 for y in range(10) if y != x if y % 2 == 0)", - "((x, y) for x in range(10) if x > 2 and x < 8 for y in range(10) if y != x and y % 2 == 0)", - ), - ( - "[x + y + z for x in range(3) if x > 0 for y in range(3) if y != x for z in range(3) if z != x if z != y]", - "[x + y + z for x in range(3) if x > 0 for y in range(3) if y != x for z in range(3) if (z != x and z != y)]", - ), - # Generator expressions with function calls in body - ( - "(sum([x for x in range(i)]) for i in range(5))", - "(sum([x for x in range(i)]) for i in range(5))", - ), - ( - "(max(x for x in range(i + 1)) for i in range(3))", - "(max(x for x in range(i + 1)) for i in range(3))", - ), - # Nested generator in iterator - ( - "([y for y in range(x)] for x in (x_ + 1 for x_ in range(5)))", - "([y for y in range(x)] for x in (x_ + 1 for x_ in range(5)))", - ), - ( - "(x * 2 for x in [y**2 for y in range(5)])", - "(x * 2 for x in [y**2 for y in range(5)])", - ), - # Mixed comprehension types - ( - "({x: [y for y in range(x)] for x in range(5)})", - "({x: [y for y in range(x)] for x in range(5)})", - ), - ( - "[{x: x**2 for x in range(i)} for i in range(5)]", - "[{x: x**2 for x in range(i)} for i in range(5)]", - ), - # Comprehensions with nested filters - ( - "([x for x in range(i) if x < i] for i in range(5) if i > 0)", - "([x for x in range(i) if x < i] for i in range(5) if i > 0)", - ), - ( - "([[x for x in range(i + j) if x < i + j] for j in range(i)] for i in range(5))", - "([[x for x in range(i + j) if x < i + j] for j in range(i)] for i in range(5))", - ), - ], -) -def test_ifs_to_if_and_nested_comprehensions(genexpr_src, expected_src): - """Test IfsToIfAnd transformer with nested comprehensions.""" - # Parse the source code - source_ast = ast.parse(genexpr_src, 
mode="eval") - expected_ast = ast.parse(expected_src, mode="eval") - - # Apply the transformer - transformer = IfsToIfAnd() - transformed_ast = transformer.visit(source_ast) - - # Compare the unparsed strings - transformed_src = ast.unparse(transformed_ast) - expected_unparsed = ast.unparse(expected_ast) - - assert transformed_src == expected_unparsed, ( - f"Transformation failed:\n" - f"Input: {genexpr_src}\n" - f"Expected: {expected_unparsed}\n" - f"Got: {transformed_src}" - ) - - @pytest.mark.parametrize( "source_src,varnames,expected_src", [ From 5c00250e8a291aa19320a2e4a2d18750a431e4b0 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 18 Jun 2025 19:30:02 -0400 Subject: [PATCH 08/19] type nit --- tests/test_internals_disassembler.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 00553db9..6f591db6 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -1,6 +1,5 @@ import ast -from types import GeneratorType -from typing import Any +from typing import Generator import pytest @@ -14,7 +13,7 @@ def compile_and_eval( node: ast.expr | ast.Expression, globals_dict: dict | None = None -) -> Any: +) -> object: """Compile an AST node and evaluate it.""" if globals_dict is None: globals_dict = {} @@ -32,7 +31,9 @@ def compile_and_eval( def assert_ast_equivalent( - genexpr: GeneratorType, reconstructed_ast: ast.AST, globals_dict: dict | None = None + genexpr: Generator[object, None, None], + reconstructed_ast: ast.AST, + globals_dict: dict | None = None, ): """Assert that a reconstructed AST produces the same results as the original generator.""" # Check AST structure From 008407c49e91ac419f5eab4c0a7f4082b6eb3b4e Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 21 Jun 2025 17:03:09 -0400 Subject: [PATCH 09/19] fix merge --- tests/test_internals_disassembler.py | 1 + 1 file changed, 1 insertion(+) diff --git 
a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 70d2f358..54783975 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -1,4 +1,5 @@ import ast +from types import GeneratorType from typing import Generator import pytest From 65262967eacb41e3f28554ef91e1c029cbff55eb Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 21 Jun 2025 17:51:54 -0400 Subject: [PATCH 10/19] lint --- effectful/internals/disassembler.py | 3 ++- effectful/ops/syntax.py | 8 ++++---- tests/test_internals_disassembler.py | 8 ++++---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 6436169f..2184680a 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1918,7 +1918,7 @@ def get_names_from_target(target): ) streams = ast.Dict(keys=[], values=[]) - all_var_names = set() + all_var_names: set[str] = set() for gen in node.generators: # Collect variable names from previous generators @@ -1929,6 +1929,7 @@ def get_names_from_target(target): all_var_names.update(target_names) # Create the value for this generator + value: ast.expr # TODO : Specify type more precisely if gen.ifs: # If there are filters, create a generator expression for the filtered iterator # Note: In the filter conditions, we need to transform previous loop variables diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 57643d34..b89e977f 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -6,7 +6,7 @@ import random import types import typing -from collections.abc import Callable, Generator, Iterable, Mapping +from collections.abc import Callable, Iterable, Mapping from typing import Annotated, Concatenate, Generic, TypeVar import tree @@ -989,8 +989,8 @@ def forexpr( raise NotImplementedError -@defterm.register(Generator) -def _(genexpr: Generator[T, None, None]) -> Expr[Iterable[T]]: 
+@defterm.register(types.GeneratorType) +def _(genexpr: types.GeneratorType[T, None, None]) -> Expr[Iterable[T]]: from effectful.internals.disassembler import GeneratorExpToForexpr, reconstruct genexpr_ast = reconstruct(genexpr) @@ -1001,7 +1001,7 @@ def _(genexpr: Generator[T, None, None]) -> Expr[Iterable[T]]: filename=forexpr_name, mode="eval", ) - return eval(forexpr_code, genexpr.gi_frame.f_globals, genexpr.gi_frame.f_locals) # type: ignore + return eval(forexpr_code, genexpr.gi_frame.f_globals, genexpr.gi_frame.f_locals) def syntactic_eq(x: Expr[T], other: Expr[T]) -> bool: diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 54783975..c6affde8 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -1,6 +1,5 @@ import ast -from types import GeneratorType -from typing import Generator +from collections.abc import Generator import pytest import tree @@ -32,11 +31,11 @@ def compile_and_eval( return eval(code, globals_dict) -def materialize(genexpr: GeneratorType) -> tree.Structure: +def materialize(genexpr: Generator[object, None, None]) -> tree.Structure: """Materialize a nested generator expression to a nested list.""" def _materialize(genexpr): - if isinstance(genexpr, GeneratorType): + if isinstance(genexpr, Generator): return tree.map_structure(_materialize, list(genexpr)) elif tree.is_nested(genexpr): return tree.map_structure(_materialize, genexpr) @@ -77,6 +76,7 @@ def assert_ast_equivalent( # Compile and evaluate the reconstructed AST reconstructed_gen = compile_and_eval(reconstructed_ast, globals_dict) + assert isinstance(reconstructed_gen, Generator) reconstructed_list = materialize(reconstructed_gen) assert reconstructed_list == original_list, ( f"AST produced {reconstructed_list}, expected {original_list}" From a9cd95855f1712ca2e2762b92c947ccf79ec0741 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 26 Jun 2025 10:01:49 -0400 Subject: [PATCH 11/19] defdata for 
iterable/iterator --- effectful/ops/syntax.py | 14 ++++++++++++++ tests/test_ops_syntax.py | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index b7377693..55107dba 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1108,6 +1108,20 @@ def trace(value: Callable[P, T]) -> Callable[P, T]: return deffn(body, *bound_sig.args, **bound_sig.kwargs) +@defdata.register(collections.abc.Iterable) +class _IterableTerm(Generic[T], _BaseTerm[collections.abc.Iterable[T]]): + @defop + def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: + raise NotImplementedError + + +@defdata.register(collections.abc.Iterator) +class _IteratorTerm(Generic[T], _IterableTerm[T]): + @defop + def __next__(self: collections.abc.Iterator[T]) -> T: + raise NotImplementedError + + def syntactic_eq(x: Expr[T], other: Expr[T]) -> bool: """Syntactic equality, ignoring the interpretation of the terms. 
diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 93402f87..30afe78c 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -1,12 +1,12 @@ import functools import inspect -from collections.abc import Callable, Mapping +from collections.abc import Callable, Iterable, Iterator, Mapping from typing import Annotated, ClassVar, TypeVar import pytest import effectful.handlers.numbers # noqa: F401 -from effectful.ops.semantics import call, evaluate, fvsof, handler +from effectful.ops.semantics import call, evaluate, fvsof, handler, typeof from effectful.ops.syntax import ( Scoped, _CustomSingleDispatchCallable, @@ -484,3 +484,32 @@ def _(self, x: bool) -> bool: # Test that the method can be called with a handler with handler({MyClass.my_singledispatch: lambda self, x: x + 6}): assert instance.my_singledispatch(5) == 11 + + +def test_defdata_iterable(): + from effectful.ops.syntax import _IterableTerm, _IteratorTerm + + @defop + def cons_iterable(*args: int) -> Iterable[int]: + raise NotImplementedError + + tm = cons_iterable(1, 2, 3) + assert isinstance(tm, Term) + assert isinstance(tm, Iterable) + assert issubclass(typeof(tm), Iterable) + assert tm.op is cons_iterable + assert tm.args == (1, 2, 3) + + tm_iter = iter(tm) + assert isinstance(tm_iter, Term) + assert isinstance(tm_iter, Iterator) + assert issubclass(typeof(tm_iter), Iterator) + assert tm_iter.op is _IterableTerm.__iter__ + + tm_iter_next = next(tm_iter) + assert isinstance(tm_iter_next, Term) + # assert isinstance(tm_iter_next, numbers.Number) # TODO + # assert issubclass(typeof(tm_iter_next), numbers.Number) + assert tm_iter_next.op is _IteratorTerm.__next__ + + assert list(tm.args) == [1, 2, 3] From 1f213e1a4af8b7b534950d757261c491c1551172 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 27 Jun 2025 09:59:07 -0400 Subject: [PATCH 12/19] ops --- effectful/ops/syntax.py | 4 ++++ tests/test_ops_syntax.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff 
--git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 55107dba..bea073bf 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1122,6 +1122,10 @@ def __next__(self: collections.abc.Iterator[T]) -> T: raise NotImplementedError +iter_ = _IterableTerm.__iter__ +next_ = _IteratorTerm.__next__ + + def syntactic_eq(x: Expr[T], other: Expr[T]) -> bool: """Syntactic equality, ignoring the interpretation of the terms. diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 30afe78c..7ccd5dd0 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -13,6 +13,8 @@ deffn, defop, defterm, + iter_, + next_, ) from effectful.ops.types import Operation, Term @@ -487,8 +489,6 @@ def _(self, x: bool) -> bool: def test_defdata_iterable(): - from effectful.ops.syntax import _IterableTerm, _IteratorTerm - @defop def cons_iterable(*args: int) -> Iterable[int]: raise NotImplementedError @@ -504,12 +504,12 @@ def cons_iterable(*args: int) -> Iterable[int]: assert isinstance(tm_iter, Term) assert isinstance(tm_iter, Iterator) assert issubclass(typeof(tm_iter), Iterator) - assert tm_iter.op is _IterableTerm.__iter__ + assert tm_iter.op is iter_ tm_iter_next = next(tm_iter) assert isinstance(tm_iter_next, Term) # assert isinstance(tm_iter_next, numbers.Number) # TODO # assert issubclass(typeof(tm_iter_next), numbers.Number) - assert tm_iter_next.op is _IteratorTerm.__next__ + assert tm_iter_next.op is next_ assert list(tm.args) == [1, 2, 3] From b0da173a64f532d53b2ef7e1a9bf1c9c8d449bf2 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 27 Jun 2025 14:02:07 -0400 Subject: [PATCH 13/19] forexpr --- effectful/internals/disassembler.py | 6 +++--- effectful/ops/syntax.py | 16 +++++++++++++--- tests/test_ops_syntax.py | 9 +++++++++ 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 2184680a..71b28f89 100644 --- a/effectful/internals/disassembler.py +++ 
b/effectful/internals/disassembler.py @@ -605,7 +605,7 @@ def handle_store_fast( new_stack = state.stack[:-1] new_result: CompExp = copy.deepcopy(state.result) - new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) + new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Load()) return replace(state, stack=new_stack, result=new_result) @@ -628,7 +628,7 @@ def handle_store_deref( new_stack = state.stack[:-1] new_result: CompExp = copy.deepcopy(state.result) - new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) + new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Load()) return replace(state, stack=new_stack, result=new_result) @@ -652,7 +652,7 @@ def handle_store_fast_load_fast( new_stack = state.stack[:-1] + [ast.Name(id=load_name, ctx=ast.Load())] new_result: CompExp = copy.deepcopy(state.result) - new_result.generators[-1].target = ast.Name(id=store_name, ctx=ast.Store()) + new_result.generators[-1].target = ast.Name(id=store_name, ctx=ast.Load()) return replace(state, stack=new_stack, result=new_result) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index eaf99d6e..5df1534e 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1132,21 +1132,31 @@ def _(genexpr: types.GeneratorType[T, None, None]) -> Expr[Iterable[T]]: filename=forexpr_name, mode="eval", ) - return eval(forexpr_code, genexpr.gi_frame.f_globals, genexpr.gi_frame.f_locals) + return eval( + forexpr_code, + genexpr.gi_frame.f_globals | {"forexpr": forexpr}, + genexpr.gi_frame.f_locals, # TODO infer types and construct stream variables + ) @defdata.register(collections.abc.Iterable) class _IterableTerm(Generic[T], _BaseTerm[collections.abc.Iterable[T]]): @defop def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: - raise NotImplementedError + if not isinstance(self, Term): + return iter(self) + else: + raise NotImplementedError @defdata.register(collections.abc.Iterator) class 
_IteratorTerm(Generic[T], _IterableTerm[T]): @defop def __next__(self: collections.abc.Iterator[T]) -> T: - raise NotImplementedError + if not isinstance(self, Term): + return next(self) + else: + raise NotImplementedError iter_ = _IterableTerm.__iter__ diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 7ccd5dd0..c41b6e85 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -13,6 +13,7 @@ deffn, defop, defterm, + forexpr, iter_, next_, ) @@ -513,3 +514,11 @@ def cons_iterable(*args: int) -> Iterable[int]: assert tm_iter_next.op is next_ assert list(tm.args) == [1, 2, 3] + + +def test_defterm_genexpr(): + xs = (x + 1 for x in range(5)) + + tm = defterm(xs) + assert isinstance(tm, Term) + assert tm.op is forexpr From 8c1a8d97dc78be0f16669b6bada551392ee233a2 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 27 Jun 2025 14:10:38 -0400 Subject: [PATCH 14/19] forexpr to defstream --- effectful/internals/disassembler.py | 18 ++-- effectful/ops/syntax.py | 14 ++-- tests/test_internals_disassembler.py | 120 +++++++++++++-------------- tests/test_ops_syntax.py | 4 +- 4 files changed, 77 insertions(+), 79 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 71b28f89..16fa6993 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1855,14 +1855,14 @@ def visit_Name(self, node: ast.Name) -> ast.Call | ast.Name: return node -class GeneratorExpToForexpr(ast.NodeTransformer): +class GeneratorExpToDefstream(ast.NodeTransformer): """ - Transform generator expressions into calls to `forexpr`. + Transform generator expressions into calls to `defstream`. 
This transformer converts generator expressions of the form: (expr for var in iter) - into calls to `forexpr`: - forexpr(expr, {var: lambda: iter}) + into calls to `defstream`: + defstream(expr, {var: lambda: iter}) It supports: - Multiple nested loops @@ -1874,22 +1874,22 @@ class GeneratorExpToForexpr(ast.NodeTransformer): >>> import ast >>> source = "(x * 2 for x in range(10))" >>> tree = ast.parse(source, mode='eval') - >>> transformer = GeneratorExpToForexpr() + >>> transformer = GeneratorExpToDefstream() >>> transformed = transformer.visit(tree) >>> ast.unparse(transformed) - 'forexpr(x() * 2, {x: lambda: range(10)})' + 'defstream(x() * 2, {x: lambda: range(10)})' >>> source = "(x for x in range(10) if x % 2 == 0)" >>> tree = ast.parse(source, mode='eval') >>> transformed = transformer.visit(tree) >>> ast.unparse(transformed) - 'forexpr(x(), {x: (x for x in range(10) if x % 2 == 0)})' + 'defstream(x(), {x: (x for x in range(10) if x % 2 == 0)})' >>> source = "((x, y) for x, y in pairs)" >>> tree = ast.parse(source, mode='eval') >>> transformed = transformer.visit(tree) >>> ast.unparse(transformed) - 'forexpr((x(), y()), {(x, y): lambda: pairs})' + 'defstream((x(), y()), {(x, y): lambda: pairs})' """ @@ -1971,7 +1971,7 @@ def get_names_from_target(target): body = self.visit(body) # Recursively transform nested generator expressions return ast.Call( - func=ast.Name(id="forexpr", ctx=ast.Load()), + func=ast.Name(id="defstream", ctx=ast.Load()), args=[body, streams], keywords=[], ) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 5df1534e..b17affff 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1110,11 +1110,9 @@ def trace(value: Callable[P, T]) -> Callable[P, T]: @defop -def forexpr( +def defstream( body: Annotated[T, Scoped[A | B]], - streams: Annotated[ - Mapping[Operation[..., S], Callable[..., Iterable[S]]], Scoped[B] - ], + streams: Annotated[Mapping[Operation[[], S], Iterable[S]], Scoped[B]], ) ->
Annotated[Iterable[T], Scoped[A]]: """A higher-order operation that represents a for-expression.""" raise NotImplementedError @@ -1122,11 +1120,11 @@ def forexpr( @defterm.register(types.GeneratorType) def _(genexpr: types.GeneratorType[T, None, None]) -> Expr[Iterable[T]]: - from effectful.internals.disassembler import GeneratorExpToForexpr, reconstruct + from effectful.internals.disassembler import GeneratorExpToDefstream, reconstruct genexpr_ast = reconstruct(genexpr) - forexpr_ast = GeneratorExpToForexpr().visit(genexpr_ast) - forexpr_name = ".".join(genexpr.gi_code.co_name.split(".")[:-1] + [""]) + forexpr_ast = GeneratorExpToDefstream().visit(genexpr_ast) + forexpr_name = ".".join(genexpr.gi_code.co_name.split(".")[:-1] + [""]) forexpr_code = compile( ast.fix_missing_locations(forexpr_ast), filename=forexpr_name, @@ -1134,7 +1132,7 @@ def _(genexpr: types.GeneratorType[T, None, None]) -> Expr[Iterable[T]]: ) return eval( forexpr_code, - genexpr.gi_frame.f_globals | {"forexpr": forexpr}, + genexpr.gi_frame.f_globals | {"defstream": defstream}, genexpr.gi_frame.f_locals, # TODO infer types and construct stream variables ) diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index c6affde8..41058a65 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -5,7 +5,7 @@ import tree from effectful.internals.disassembler import ( - GeneratorExpToForexpr, + GeneratorExpToDefstream, NameToCall, ensure_ast, reconstruct, @@ -925,105 +925,105 @@ def test_name_to_call_nested_comprehensions(source_src, varnames, expected_src): # Simple generator expressions ( "(x for x in range(10))", - "forexpr(x(), {x: lambda: range(10)})", + "defstream(x(), {x: lambda: range(10)})", ), ( "(x * 2 for x in range(10))", - "forexpr(x() * 2, {x: lambda: range(10)})", + "defstream(x() * 2, {x: lambda: range(10)})", ), ( "(x + 1 for x in items)", - "forexpr(x() + 1, {x: lambda: items})", + "defstream(x() + 1, {x: 
lambda: items})", ), # Complex expressions ( "(x ** 2 + 2 * x + 1 for x in range(5))", - "forexpr(x() ** 2 + 2 * x() + 1, {x: lambda: range(5)})", + "defstream(x() ** 2 + 2 * x() + 1, {x: lambda: range(5)})", ), ( "(f(x) for x in data)", - "forexpr(f(x()), {x: lambda: data})", + "defstream(f(x()), {x: lambda: data})", ), # Multiple nested loops ( "(x + y for x in range(3) for y in range(4))", - "forexpr(x() + y(), {x: lambda: range(3), y: lambda: range(4)})", + "defstream(x() + y(), {x: lambda: range(3), y: lambda: range(4)})", ), ( "(x * y for x in items1 for y in items2)", - "forexpr(x() * y(), {x: lambda: items1, y: lambda: items2})", + "defstream(x() * y(), {x: lambda: items1, y: lambda: items2})", ), # Nested loops with dependencies ( "(x + y for x in range(3) for y in range(x))", - "forexpr(x() + y(), {x: lambda: range(3), y: lambda: range(x())})", + "defstream(x() + y(), {x: lambda: range(3), y: lambda: range(x())})", ), ( "((x, y) for x in range(3) for y in range(x, 5))", - "forexpr((x(), y()), {x: lambda: range(3), y: lambda: range(x(), 5)})", + "defstream((x(), y()), {x: lambda: range(3), y: lambda: range(x(), 5)})", ), # Triple nested loops ( "(x + y + z for x in range(2) for y in range(2) for z in range(2))", - "forexpr(x() + y() + z(), {x: lambda: range(2), y: lambda: range(2), z: lambda: range(2)})", + "defstream(x() + y() + z(), {x: lambda: range(2), y: lambda: range(2), z: lambda: range(2)})", ), # Complex iterators ( "(x for x in [1, 2, 3])", - "forexpr(x(), {x: lambda: [1, 2, 3]})", + "defstream(x(), {x: lambda: [1, 2, 3]})", ), ( "(x for x in list(range(5)))", - "forexpr(x(), {x: lambda: list(range(5))})", + "defstream(x(), {x: lambda: list(range(5))})", ), # Expressions with function calls on iterators ( "(x for x in sorted(items))", - "forexpr(x(), {x: lambda: sorted(items)})", + "defstream(x(), {x: lambda: sorted(items)})", ), # Generator expressions with filters ( "(x for x in range(10) if x % 2 == 0)", - "forexpr(x(), {x: (x for x in 
range(10) if x % 2 == 0)})", + "defstream(x(), {x: (x for x in range(10) if x % 2 == 0)})", ), ( "(x + y for x in range(3) if x > 0 for y in range(3))", - "forexpr(x() + y(), {x: (x for x in range(3) if x > 0), y: lambda: range(3)})", + "defstream(x() + y(), {x: (x for x in range(3) if x > 0), y: lambda: range(3)})", ), ( "(x * 2 for x in items if x > 5)", - "forexpr(x() * 2, {x: (x for x in items if x > 5)})", + "defstream(x() * 2, {x: (x for x in items if x > 5)})", ), ( "(x + y for x in range(5) for y in range(5) if x < y)", - "forexpr(x() + y(), {x: lambda: range(5), y: (y for y in range(5) if x() < y)})", + "defstream(x() + y(), {x: lambda: range(5), y: (y for y in range(5) if x() < y)})", ), ( "(x for x in range(20) if x % 2 == 0 if x % 3 == 0)", - "forexpr(x(), {x: (x for x in range(20) if x % 2 == 0 if x % 3 == 0)})", + "defstream(x(), {x: (x for x in range(20) if x % 2 == 0 if x % 3 == 0)})", ), # Generator expressions with unpacking ( "((x, y) for x, y in pairs)", - "forexpr((x(), y()), {(x, y): lambda: pairs})", + "defstream((x(), y()), {(x, y): lambda: pairs})", ), ( "(a + b for a, b in zip(list1, list2))", - "forexpr(a() + b(), {(a, b): lambda: zip(list1, list2)})", + "defstream(a() + b(), {(a, b): lambda: zip(list1, list2)})", ), ( "(x + y + z for x, (y, z) in nested_pairs)", - "forexpr(x() + y() + z(), {(x, (y, z)): lambda: nested_pairs})", + "defstream(x() + y() + z(), {(x, (y, z)): lambda: nested_pairs})", ), ], ) -def test_generator_exp_to_forexpr_transformer(genexpr_src, expected_src): - """Test GeneratorExpToForexpr transformer converts generator expressions to forexpr calls.""" +def test_generator_exp_to_defstream_transformer(genexpr_src, expected_src): + """Test GeneratorExpToDefstream transformer converts generator expressions to defstream calls.""" # Parse the source code source_ast = ast.parse(genexpr_src, mode="eval") # Apply the transformer - transformer = GeneratorExpToForexpr() + transformer = GeneratorExpToDefstream() # For test cases 
that should raise NotImplementedError if expected_src == "NOT_APPLICABLE": @@ -1052,156 +1052,156 @@ def test_generator_exp_to_forexpr_transformer(genexpr_src, expected_src): # Generator expressions yielding comprehensions ( "([x for x in range(i)] for i in range(5))", - "forexpr([x for x in range(i())], {i: lambda: range(5)})", + "defstream([x for x in range(i())], {i: lambda: range(5)})", ), ( "({x: x**2 for x in range(i)} for i in range(5))", - "forexpr({x: x**2 for x in range(i())}, {i: lambda: range(5)})", + "defstream({x: x**2 for x in range(i())}, {i: lambda: range(5)})", ), ( "({x for x in range(i) if x > 2} for i in range(10))", - "forexpr({x for x in range(i()) if x > 2}, {i: lambda: range(10)})", + "defstream({x for x in range(i()) if x > 2}, {i: lambda: range(10)})", ), # Double nested comprehensions ( "([[x for x in range(j)] for j in range(i)] for i in range(3))", - "forexpr([[x for x in range(j)] for j in range(i())], {i: lambda: range(3)})", + "defstream([[x for x in range(j)] for j in range(i())], {i: lambda: range(3)})", ), ( "([[x + y for y in range(j)] for j in range(i)] for i in range(3))", - "forexpr([[x + y for y in range(j)] for j in range(i())], {i: lambda: range(3)})", + "defstream([[x + y for y in range(j)] for j in range(i())], {i: lambda: range(3)})", ), # Generator with comprehension in iterator ( "(x * 2 for x in [y**2 for y in range(5)])", - "forexpr(x() * 2, {x: lambda: [y**2 for y in range(5)]})", + "defstream(x() * 2, {x: lambda: [y**2 for y in range(5)]})", ), ( "(x for x in {y: y**2 for y in range(3)}.values())", - "forexpr(x(), {x: lambda: {y: y**2 for y in range(3)}.values()})", + "defstream(x(), {x: lambda: {y: y**2 for y in range(3)}.values()})", ), # Complex expressions with nested calls ( "(sum([x for x in range(i)]) for i in range(5))", - "forexpr(sum([x for x in range(i())]), {i: lambda: range(5)})", + "defstream(sum([x for x in range(i())]), {i: lambda: range(5)})", ), ( "(max(x for x in range(i + 1)) for i in 
range(3))", - "forexpr(max(forexpr(x(), {x: lambda: range(i() + 1)})), {i: lambda: range(3)})", + "defstream(max(defstream(x(), {x: lambda: range(i() + 1)})), {i: lambda: range(3)})", ), ( "(list(enumerate(x + 1 for x in range(i + 1))) for i in range(3))", - "forexpr(list(enumerate(forexpr(x() + 1, {x: lambda: range(i() + 1)}))), {i: lambda: range(3)})", + "defstream(list(enumerate(defstream(x() + 1, {x: lambda: range(i() + 1)}))), {i: lambda: range(3)})", ), # Nested generators in iterator ( "([y for y in range(x)] for x in (z + 1 for z in range(5)))", - "forexpr([y for y in range(x())], {x: lambda: forexpr(z() + 1, {z: lambda: range(5)})})", + "defstream([y for y in range(x())], {x: lambda: defstream(z() + 1, {z: lambda: range(5)})})", ), # Complex filters with nested comprehensions ( "(x for x in range(10) if x in [y**2 for y in range(5)])", - "forexpr(x(), {x: (x for x in range(10) if x in [y**2 for y in range(5)])})", + "defstream(x(), {x: (x for x in range(10) if x in [y**2 for y in range(5)])})", ), ( "([x, y] for x in range(3) if x > 0 for y in range(3) if y in [z for z in range(x)])", - "forexpr([x(), y()], {x: (x for x in range(3) if x > 0), y: (y for y in range(3) if y in [z for z in range(x())])})", + "defstream([x(), y()], {x: (x for x in range(3) if x > 0), y: (y for y in range(3) if y in [z for z in range(x())])})", ), # Multiple filters with dependencies ( "(x + y for x in range(5) if x > 1 for y in range(x) if y < x - 1)", - "forexpr(x() + y(), {x: (x for x in range(5) if x > 1), y: (y for y in range(x()) if y < x() - 1)})", + "defstream(x() + y(), {x: (x for x in range(5) if x > 1), y: (y for y in range(x()) if y < x() - 1)})", ), ( "((x, y, z) for x in range(3) for y in range(x, 5) if y > x for z in range(y) if z < y)", - "forexpr((x(), y(), z()), {x: lambda: range(3), y: (y for y in range(x(), 5) if y > x()), z: (z for z in range(y()) if z < y())})", + "defstream((x(), y(), z()), {x: lambda: range(3), y: (y for y in range(x(), 5) if y > x()), z: 
(z for z in range(y()) if z < y())})", ), # Unpacking with nested structures ( "(a + b + c for (a, b), c in [((1, 2), 3), ((4, 5), 6)])", - "forexpr(a() + b() + c(), {((a, b), c): lambda: [((1, 2), 3), ((4, 5), 6)]})", + "defstream(a() + b() + c(), {((a, b), c): lambda: [((1, 2), 3), ((4, 5), 6)]})", ), ( "(x + sum(lst) for x, lst in [(1, [2, 3]), (4, [5, 6])])", - "forexpr(x() + sum(lst()), {(x, lst): lambda: [(1, [2, 3]), (4, [5, 6])]})", + "defstream(x() + sum(lst()), {(x, lst): lambda: [(1, [2, 3]), (4, [5, 6])]})", ), # Complex iterators ( "(x for x in sorted([y**2 for y in range(5)]))", - "forexpr(x(), {x: lambda: sorted([y**2 for y in range(5)])})", + "defstream(x(), {x: lambda: sorted([y**2 for y in range(5)])})", ), ( "(item for sublist in [[1, 2], [3, 4], [5, 6]] for item in sublist)", - "forexpr(item(), {sublist: lambda: [[1, 2], [3, 4], [5, 6]], item: lambda: sublist()})", + "defstream(item(), {sublist: lambda: [[1, 2], [3, 4], [5, 6]], item: lambda: sublist()})", ), # Expressions with method calls ( "(s.upper() for s in ['hello', 'world'] if s.startswith('h'))", - "forexpr(s().upper(), {s: (s for s in ['hello', 'world'] if s.startswith('h'))})", + "defstream(s().upper(), {s: (s for s in ['hello', 'world'] if s.startswith('h'))})", ), ( "(obj.value for obj in objects if hasattr(obj, 'value'))", - "forexpr(obj().value, {obj: (obj for obj in objects if hasattr(obj, 'value'))})", + "defstream(obj().value, {obj: (obj for obj in objects if hasattr(obj, 'value'))})", ), # CRITICAL: Generator expressions yielding generator expressions ( "((x for x in range(i)) for i in range(5))", - "forexpr(forexpr(x(), {x: lambda: range(i())}), {i: lambda: range(5)})", + "defstream(defstream(x(), {x: lambda: range(i())}), {i: lambda: range(5)})", ), ( "((x * 2 for x in range(i)) for i in range(3))", - "forexpr(forexpr(x() * 2, {x: lambda: range(i())}), {i: lambda: range(3)})", + "defstream(defstream(x() * 2, {x: lambda: range(i())}), {i: lambda: range(3)})", ), ( "((x + y 
for x in range(3) for y in range(x)) for i in range(2))", - "forexpr(forexpr(x() + y(), {x: lambda: range(3), y: lambda: range(x())}), {i: lambda: range(2)})", + "defstream(defstream(x() + y(), {x: lambda: range(3), y: lambda: range(x())}), {i: lambda: range(2)})", ), # Generator yielding filtered generator ( "((x for x in range(10) if x % 2 == 0) for i in range(3))", - "forexpr(forexpr(x(), {x: (x for x in range(10) if x % 2 == 0)}), {i: lambda: range(3)})", + "defstream(defstream(x(), {x: (x for x in range(10) if x % 2 == 0)}), {i: lambda: range(3)})", ), ( "((x for x in range(i) if x > 0) for i in range(5))", - "forexpr(forexpr(x(), {x: (x for x in range(i()) if x > 0)}), {i: lambda: range(5)})", + "defstream(defstream(x(), {x: (x for x in range(i()) if x > 0)}), {i: lambda: range(5)})", ), # Nested generators with multiple levels ( "((y for y in (x for x in range(i))) for i in range(3))", - "forexpr(forexpr(y(), {y: lambda: forexpr(x(), {x: lambda: range(i())})}), {i: lambda: range(3)})", + "defstream(defstream(y(), {y: lambda: defstream(x(), {x: lambda: range(i())})}), {i: lambda: range(3)})", ), ( "(((x + y for x in range(2)) for y in range(3)) for z in range(4))", - "forexpr(forexpr(forexpr(x() + y(), {x: lambda: range(2)}), {y: lambda: range(3)}), {z: lambda: range(4)})", + "defstream(defstream(defstream(x() + y(), {x: lambda: range(2)}), {y: lambda: range(3)}), {z: lambda: range(4)})", ), # Generator with unpacking yielding generator ( "((x + b for x in range(a)) for a, b in [(2, 3), (4, 5)])", - "forexpr(forexpr(x() + b(), {x: lambda: range(a())}), {(a, b): lambda: [(2, 3), (4, 5)]})", + "defstream(defstream(x() + b(), {x: lambda: range(a())}), {(a, b): lambda: [(2, 3), (4, 5)]})", ), # Complex case: generator yielding generator with filters and dependencies ( "((x + y for x in range(i) if x > 0 for y in range(x)) for i in range(5) if i > 2)", - "forexpr(forexpr(x() + y(), {x: (x for x in range(i()) if x > 0), y: lambda: range(x())}), {i: (i for i in 
range(5) if i > 2)})", + "defstream(defstream(x() + y(), {x: (x for x in range(i()) if x > 0), y: lambda: range(x())}), {i: (i for i in range(5) if i > 2)})", ), # Generator expression yielding sum of generator expression ( "(sum(x for x in range(i)) for i in range(5))", - "forexpr(sum(forexpr(x(), {x: lambda: range(i())})), {i: lambda: range(5)})", + "defstream(sum(defstream(x(), {x: lambda: range(i())})), {i: lambda: range(5)})", ), ( "(max(x * 2 for x in range(i) if x > 0) for i in range(10))", - "forexpr(max(forexpr(x() * 2, {x: (x for x in range(i()) if x > 0)})), {i: lambda: range(10)})", + "defstream(max(defstream(x() * 2, {x: (x for x in range(i()) if x > 0)})), {i: lambda: range(10)})", ), ], ) -def test_generator_exp_to_forexpr_nested_comprehensions(genexpr_src, expected_src): - """Test GeneratorExpToForexpr transformer with nested comprehensions and complex expressions.""" +def test_generator_exp_to_defstream_nested_comprehensions(genexpr_src, expected_src): + """Test GeneratorExpToDefstream transformer with nested comprehensions and complex expressions.""" # Parse the source code source_ast = ast.parse(genexpr_src, mode="eval") # Apply the transformer - transformer = GeneratorExpToForexpr() + transformer = GeneratorExpToDefstream() transformed_ast = transformer.visit(source_ast) expected_ast = ast.parse(expected_src, mode="eval") diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index c41b6e85..891d3939 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -12,8 +12,8 @@ _CustomSingleDispatchCallable, deffn, defop, + defstream, defterm, - forexpr, iter_, next_, ) @@ -521,4 +521,4 @@ def test_defterm_genexpr(): tm = defterm(xs) assert isinstance(tm, Term) - assert tm.op is forexpr + assert tm.op is defstream From 3e4ca85f96ab42fe20869e712e1f1519371a6385 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 27 Jun 2025 14:12:44 -0400 Subject: [PATCH 15/19] use correct Name context --- effectful/internals/disassembler.py | 6 +++--- 
1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 16fa6993..50229478 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -605,7 +605,7 @@ def handle_store_fast( new_stack = state.stack[:-1] new_result: CompExp = copy.deepcopy(state.result) - new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Load()) + new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) return replace(state, stack=new_stack, result=new_result) @@ -628,7 +628,7 @@ def handle_store_deref( new_stack = state.stack[:-1] new_result: CompExp = copy.deepcopy(state.result) - new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Load()) + new_result.generators[-1].target = ast.Name(id=var_name, ctx=ast.Store()) return replace(state, stack=new_stack, result=new_result) @@ -652,7 +652,7 @@ def handle_store_fast_load_fast( new_stack = state.stack[:-1] + [ast.Name(id=load_name, ctx=ast.Load())] new_result: CompExp = copy.deepcopy(state.result) - new_result.generators[-1].target = ast.Name(id=store_name, ctx=ast.Load()) + new_result.generators[-1].target = ast.Name(id=store_name, ctx=ast.Store()) return replace(state, stack=new_stack, result=new_result) From cf9b0df2ba71f95c97b1c834fb37f683e41cea13 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 27 Jun 2025 14:25:06 -0400 Subject: [PATCH 16/19] store and load --- effectful/internals/disassembler.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index 50229478..a7a7f0bc 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1851,10 +1851,27 @@ def __init__(self, varnames: set[str]): def visit_Name(self, node: ast.Name) -> ast.Call | ast.Name: if node.id in self.varnames and isinstance(node.ctx, ast.Load): return ast.Call(node, args=[], 
keywords=[]) + elif node.id in self.varnames and isinstance(node.ctx, ast.Store): + # If it's a store context, we don't want to transform it into a call + return ast.Name(id=node.id, ctx=ast.Load()) else: return node +class NameStoreToLoad(ast.NodeTransformer): + """ + Transform variable names in store context to load context. + This transformer changes all variable names in store context (e.g., assignments) + to load context, effectively treating them as read-only variables. + """ + + def visit_Name(self, node: ast.Name) -> ast.Name: + return ast.Name(id=node.id, ctx=ast.Load()) + + def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple: + return ast.Tuple(elts=[self.visit(elt) for elt in node.elts], ctx=ast.Load()) + + class GeneratorExpToDefstream(ast.NodeTransformer): """ Transform generator expressions into calls to `defstream`. @@ -1962,7 +1979,7 @@ def get_names_from_target(target): body=self.visit(NameToCall(prev_var_names).visit(gen.iter)), ) - streams.keys.append(gen.target) + streams.keys.append(NameStoreToLoad().visit(gen.target)) streams.values.append(value) # Transform the body expression From 00011babd49f0b0367fa2617c81b61d2e8ce118e Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 27 Jun 2025 15:40:53 -0400 Subject: [PATCH 17/19] register subclass --- effectful/ops/syntax.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index b17affff..0abac480 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1138,6 +1138,7 @@ def _(genexpr: types.GeneratorType[T, None, None]) -> Expr[Iterable[T]]: @defdata.register(collections.abc.Iterable) +@collections.abc.Iterable.register class _IterableTerm(Generic[T], _BaseTerm[collections.abc.Iterable[T]]): @defop def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: @@ -1148,6 +1149,7 @@ def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: @defdata.register(collections.abc.Iterator) 
+@collections.abc.Iterator.register class _IteratorTerm(Generic[T], _IterableTerm[T]): @defop def __next__(self: collections.abc.Iterator[T]) -> T: From 19607db02c2c87d2367a9d687ddb1b98ac9ebd0e Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 27 Jun 2025 15:41:43 -0400 Subject: [PATCH 18/19] reorder --- effectful/ops/syntax.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 0abac480..aa600286 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1118,25 +1118,6 @@ def defstream( raise NotImplementedError -@defterm.register(types.GeneratorType) -def _(genexpr: types.GeneratorType[T, None, None]) -> Expr[Iterable[T]]: - from effectful.internals.disassembler import GeneratorExpToDefstream, reconstruct - - genexpr_ast = reconstruct(genexpr) - forexpr_ast = GeneratorExpToDefstream().visit(genexpr_ast) - forexpr_name = ".".join(genexpr.gi_code.co_name.split(".")[:-1] + [""]) - forexpr_code = compile( - ast.fix_missing_locations(forexpr_ast), - filename=forexpr_name, - mode="eval", - ) - return eval( - forexpr_code, - genexpr.gi_frame.f_globals | {"defstream": defstream}, - genexpr.gi_frame.f_locals, # TODO infer types and construct stream variables - ) - - @defdata.register(collections.abc.Iterable) @collections.abc.Iterable.register class _IterableTerm(Generic[T], _BaseTerm[collections.abc.Iterable[T]]): @@ -1163,6 +1144,25 @@ def __next__(self: collections.abc.Iterator[T]) -> T: next_ = _IteratorTerm.__next__ +@defterm.register(types.GeneratorType) +def _(genexpr: types.GeneratorType[T, None, None]) -> Expr[Iterable[T]]: + from effectful.internals.disassembler import GeneratorExpToDefstream, reconstruct + + genexpr_ast = reconstruct(genexpr) + forexpr_ast = GeneratorExpToDefstream().visit(genexpr_ast) + forexpr_name = ".".join(genexpr.gi_code.co_name.split(".")[:-1] + [""]) + forexpr_code = compile( + ast.fix_missing_locations(forexpr_ast), + 
filename=forexpr_name, + mode="eval", + ) + return eval( + forexpr_code, + genexpr.gi_frame.f_globals | {"defstream": defstream}, + genexpr.gi_frame.f_locals, # TODO infer types and construct stream variables + ) + + def syntactic_eq(x: Expr[T], other: Expr[T]) -> bool: """Syntactic equality, ignoring the interpretation of the terms. From 62aeaa0850cdb260e6e70c1b9ae5b6432ff85872 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 27 Jun 2025 15:48:26 -0400 Subject: [PATCH 19/19] register not needed --- effectful/internals/disassembler.py | 13 +--- effectful/ops/syntax.py | 2 - tests/test_internals_disassembler.py | 90 ++++++++++++++-------------- 3 files changed, 47 insertions(+), 58 deletions(-) diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index a7a7f0bc..b33ae84d 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1967,17 +1967,8 @@ def get_names_from_target(target): ) value = filtered_gen else: - # No filters, create a lambda - value = ast.Lambda( - args=ast.arguments( - posonlyargs=[], - args=[], - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ), - body=self.visit(NameToCall(prev_var_names).visit(gen.iter)), - ) + # No filters, create an expression + value = self.visit(NameToCall(prev_var_names).visit(gen.iter)) streams.keys.append(NameStoreToLoad().visit(gen.target)) streams.values.append(value) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index aa600286..b50fb354 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1119,7 +1119,6 @@ def defstream( @defdata.register(collections.abc.Iterable) -@collections.abc.Iterable.register class _IterableTerm(Generic[T], _BaseTerm[collections.abc.Iterable[T]]): @defop def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: @@ -1130,7 +1129,6 @@ def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: @defdata.register(collections.abc.Iterator) 
-@collections.abc.Iterator.register class _IteratorTerm(Generic[T], _IterableTerm[T]): @defop def __next__(self: collections.abc.Iterator[T]) -> T: diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index 41058a65..5da3e998 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -925,61 +925,61 @@ def test_name_to_call_nested_comprehensions(source_src, varnames, expected_src): # Simple generator expressions ( "(x for x in range(10))", - "defstream(x(), {x: lambda: range(10)})", + "defstream(x(), {x: range(10)})", ), ( "(x * 2 for x in range(10))", - "defstream(x() * 2, {x: lambda: range(10)})", + "defstream(x() * 2, {x: range(10)})", ), ( "(x + 1 for x in items)", - "defstream(x() + 1, {x: lambda: items})", + "defstream(x() + 1, {x: items})", ), # Complex expressions ( "(x ** 2 + 2 * x + 1 for x in range(5))", - "defstream(x() ** 2 + 2 * x() + 1, {x: lambda: range(5)})", + "defstream(x() ** 2 + 2 * x() + 1, {x: range(5)})", ), ( "(f(x) for x in data)", - "defstream(f(x()), {x: lambda: data})", + "defstream(f(x()), {x: data})", ), # Multiple nested loops ( "(x + y for x in range(3) for y in range(4))", - "defstream(x() + y(), {x: lambda: range(3), y: lambda: range(4)})", + "defstream(x() + y(), {x: range(3), y: range(4)})", ), ( "(x * y for x in items1 for y in items2)", - "defstream(x() * y(), {x: lambda: items1, y: lambda: items2})", + "defstream(x() * y(), {x: items1, y: items2})", ), # Nested loops with dependencies ( "(x + y for x in range(3) for y in range(x))", - "defstream(x() + y(), {x: lambda: range(3), y: lambda: range(x())})", + "defstream(x() + y(), {x: range(3), y: range(x())})", ), ( "((x, y) for x in range(3) for y in range(x, 5))", - "defstream((x(), y()), {x: lambda: range(3), y: lambda: range(x(), 5)})", + "defstream((x(), y()), {x: range(3), y: range(x(), 5)})", ), # Triple nested loops ( "(x + y + z for x in range(2) for y in range(2) for z in range(2))", - 
"defstream(x() + y() + z(), {x: lambda: range(2), y: lambda: range(2), z: lambda: range(2)})", + "defstream(x() + y() + z(), {x: range(2), y: range(2), z: range(2)})", ), # Complex iterators ( "(x for x in [1, 2, 3])", - "defstream(x(), {x: lambda: [1, 2, 3]})", + "defstream(x(), {x: [1, 2, 3]})", ), ( "(x for x in list(range(5)))", - "defstream(x(), {x: lambda: list(range(5))})", + "defstream(x(), {x: list(range(5))})", ), # Expressions with function calls on iterators ( "(x for x in sorted(items))", - "defstream(x(), {x: lambda: sorted(items)})", + "defstream(x(), {x: sorted(items)})", ), # Generator expressions with filters ( @@ -988,7 +988,7 @@ def test_name_to_call_nested_comprehensions(source_src, varnames, expected_src): ), ( "(x + y for x in range(3) if x > 0 for y in range(3))", - "defstream(x() + y(), {x: (x for x in range(3) if x > 0), y: lambda: range(3)})", + "defstream(x() + y(), {x: (x for x in range(3) if x > 0), y: range(3)})", ), ( "(x * 2 for x in items if x > 5)", @@ -996,7 +996,7 @@ def test_name_to_call_nested_comprehensions(source_src, varnames, expected_src): ), ( "(x + y for x in range(5) for y in range(5) if x < y)", - "defstream(x() + y(), {x: lambda: range(5), y: (y for y in range(5) if x() < y)})", + "defstream(x() + y(), {x: range(5), y: (y for y in range(5) if x() < y)})", ), ( "(x for x in range(20) if x % 2 == 0 if x % 3 == 0)", @@ -1005,15 +1005,15 @@ def test_name_to_call_nested_comprehensions(source_src, varnames, expected_src): # Generator expressions with unpacking ( "((x, y) for x, y in pairs)", - "defstream((x(), y()), {(x, y): lambda: pairs})", + "defstream((x(), y()), {(x, y): pairs})", ), ( "(a + b for a, b in zip(list1, list2))", - "defstream(a() + b(), {(a, b): lambda: zip(list1, list2)})", + "defstream(a() + b(), {(a, b): zip(list1, list2)})", ), ( "(x + y + z for x, (y, z) in nested_pairs)", - "defstream(x() + y() + z(), {(x, (y, z)): lambda: nested_pairs})", + "defstream(x() + y() + z(), {(x, (y, z)): nested_pairs})", 
), ], ) @@ -1052,51 +1052,51 @@ def test_generator_exp_to_defstream_transformer(genexpr_src, expected_src): # Generator expressions yielding comprehensions ( "([x for x in range(i)] for i in range(5))", - "defstream([x for x in range(i())], {i: lambda: range(5)})", + "defstream([x for x in range(i())], {i: range(5)})", ), ( "({x: x**2 for x in range(i)} for i in range(5))", - "defstream({x: x**2 for x in range(i())}, {i: lambda: range(5)})", + "defstream({x: x**2 for x in range(i())}, {i: range(5)})", ), ( "({x for x in range(i) if x > 2} for i in range(10))", - "defstream({x for x in range(i()) if x > 2}, {i: lambda: range(10)})", + "defstream({x for x in range(i()) if x > 2}, {i: range(10)})", ), # Double nested comprehensions ( "([[x for x in range(j)] for j in range(i)] for i in range(3))", - "defstream([[x for x in range(j)] for j in range(i())], {i: lambda: range(3)})", + "defstream([[x for x in range(j)] for j in range(i())], {i: range(3)})", ), ( "([[x + y for y in range(j)] for j in range(i)] for i in range(3))", - "defstream([[x + y for y in range(j)] for j in range(i())], {i: lambda: range(3)})", + "defstream([[x + y for y in range(j)] for j in range(i())], {i: range(3)})", ), # Generator with comprehension in iterator ( "(x * 2 for x in [y**2 for y in range(5)])", - "defstream(x() * 2, {x: lambda: [y**2 for y in range(5)]})", + "defstream(x() * 2, {x: [y**2 for y in range(5)]})", ), ( "(x for x in {y: y**2 for y in range(3)}.values())", - "defstream(x(), {x: lambda: {y: y**2 for y in range(3)}.values()})", + "defstream(x(), {x: {y: y**2 for y in range(3)}.values()})", ), # Complex expressions with nested calls ( "(sum([x for x in range(i)]) for i in range(5))", - "defstream(sum([x for x in range(i())]), {i: lambda: range(5)})", + "defstream(sum([x for x in range(i())]), {i: range(5)})", ), ( "(max(x for x in range(i + 1)) for i in range(3))", - "defstream(max(defstream(x(), {x: lambda: range(i() + 1)})), {i: lambda: range(3)})", + 
"defstream(max(defstream(x(), {x: range(i() + 1)})), {i: range(3)})", ), ( "(list(enumerate(x + 1 for x in range(i + 1))) for i in range(3))", - "defstream(list(enumerate(defstream(x() + 1, {x: lambda: range(i() + 1)}))), {i: lambda: range(3)})", + "defstream(list(enumerate(defstream(x() + 1, {x: range(i() + 1)}))), {i: range(3)})", ), # Nested generators in iterator ( "([y for y in range(x)] for x in (z + 1 for z in range(5)))", - "defstream([y for y in range(x())], {x: lambda: defstream(z() + 1, {z: lambda: range(5)})})", + "defstream([y for y in range(x())], {x: defstream(z() + 1, {z: range(5)})})", ), # Complex filters with nested comprehensions ( @@ -1114,25 +1114,25 @@ def test_generator_exp_to_defstream_transformer(genexpr_src, expected_src): ), ( "((x, y, z) for x in range(3) for y in range(x, 5) if y > x for z in range(y) if z < y)", - "defstream((x(), y(), z()), {x: lambda: range(3), y: (y for y in range(x(), 5) if y > x()), z: (z for z in range(y()) if z < y())})", + "defstream((x(), y(), z()), {x: range(3), y: (y for y in range(x(), 5) if y > x()), z: (z for z in range(y()) if z < y())})", ), # Unpacking with nested structures ( "(a + b + c for (a, b), c in [((1, 2), 3), ((4, 5), 6)])", - "defstream(a() + b() + c(), {((a, b), c): lambda: [((1, 2), 3), ((4, 5), 6)]})", + "defstream(a() + b() + c(), {((a, b), c): [((1, 2), 3), ((4, 5), 6)]})", ), ( "(x + sum(lst) for x, lst in [(1, [2, 3]), (4, [5, 6])])", - "defstream(x() + sum(lst()), {(x, lst): lambda: [(1, [2, 3]), (4, [5, 6])]})", + "defstream(x() + sum(lst()), {(x, lst): [(1, [2, 3]), (4, [5, 6])]})", ), # Complex iterators ( "(x for x in sorted([y**2 for y in range(5)]))", - "defstream(x(), {x: lambda: sorted([y**2 for y in range(5)])})", + "defstream(x(), {x: sorted([y**2 for y in range(5)])})", ), ( "(item for sublist in [[1, 2], [3, 4], [5, 6]] for item in sublist)", - "defstream(item(), {sublist: lambda: [[1, 2], [3, 4], [5, 6]], item: lambda: sublist()})", + "defstream(item(), {sublist: [[1, 
2], [3, 4], [5, 6]], item: sublist()})", ), # Expressions with method calls ( @@ -1146,52 +1146,52 @@ def test_generator_exp_to_defstream_transformer(genexpr_src, expected_src): # CRITICAL: Generator expressions yielding generator expressions ( "((x for x in range(i)) for i in range(5))", - "defstream(defstream(x(), {x: lambda: range(i())}), {i: lambda: range(5)})", + "defstream(defstream(x(), {x: range(i())}), {i: range(5)})", ), ( "((x * 2 for x in range(i)) for i in range(3))", - "defstream(defstream(x() * 2, {x: lambda: range(i())}), {i: lambda: range(3)})", + "defstream(defstream(x() * 2, {x: range(i())}), {i: range(3)})", ), ( "((x + y for x in range(3) for y in range(x)) for i in range(2))", - "defstream(defstream(x() + y(), {x: lambda: range(3), y: lambda: range(x())}), {i: lambda: range(2)})", + "defstream(defstream(x() + y(), {x: range(3), y: range(x())}), {i: range(2)})", ), # Generator yielding filtered generator ( "((x for x in range(10) if x % 2 == 0) for i in range(3))", - "defstream(defstream(x(), {x: (x for x in range(10) if x % 2 == 0)}), {i: lambda: range(3)})", + "defstream(defstream(x(), {x: (x for x in range(10) if x % 2 == 0)}), {i: range(3)})", ), ( "((x for x in range(i) if x > 0) for i in range(5))", - "defstream(defstream(x(), {x: (x for x in range(i()) if x > 0)}), {i: lambda: range(5)})", + "defstream(defstream(x(), {x: (x for x in range(i()) if x > 0)}), {i: range(5)})", ), # Nested generators with multiple levels ( "((y for y in (x for x in range(i))) for i in range(3))", - "defstream(defstream(y(), {y: lambda: defstream(x(), {x: lambda: range(i())})}), {i: lambda: range(3)})", + "defstream(defstream(y(), {y: defstream(x(), {x: range(i())})}), {i: range(3)})", ), ( "(((x + y for x in range(2)) for y in range(3)) for z in range(4))", - "defstream(defstream(defstream(x() + y(), {x: lambda: range(2)}), {y: lambda: range(3)}), {z: lambda: range(4)})", + "defstream(defstream(defstream(x() + y(), {x: range(2)}), {y: range(3)}), {z: 
range(4)})", ), # Generator with unpacking yielding generator ( "((x + b for x in range(a)) for a, b in [(2, 3), (4, 5)])", - "defstream(defstream(x() + b(), {x: lambda: range(a())}), {(a, b): lambda: [(2, 3), (4, 5)]})", + "defstream(defstream(x() + b(), {x: range(a())}), {(a, b): [(2, 3), (4, 5)]})", ), # Complex case: generator yielding generator with filters and dependencies ( "((x + y for x in range(i) if x > 0 for y in range(x)) for i in range(5) if i > 2)", - "defstream(defstream(x() + y(), {x: (x for x in range(i()) if x > 0), y: lambda: range(x())}), {i: (i for i in range(5) if i > 2)})", + "defstream(defstream(x() + y(), {x: (x for x in range(i()) if x > 0), y: range(x())}), {i: (i for i in range(5) if i > 2)})", ), # Generator expression yielding sum of generator expression ( "(sum(x for x in range(i)) for i in range(5))", - "defstream(sum(defstream(x(), {x: lambda: range(i())})), {i: lambda: range(5)})", + "defstream(sum(defstream(x(), {x: range(i())})), {i: range(5)})", ), ( "(max(x * 2 for x in range(i) if x > 0) for i in range(10))", - "defstream(max(defstream(x() * 2, {x: (x for x in range(i()) if x > 0)})), {i: lambda: range(10)})", + "defstream(max(defstream(x() * 2, {x: (x for x in range(i()) if x > 0)})), {i: range(10)})", ), ], )