diff --git a/effectful/internals/disassembler.py b/effectful/internals/disassembler.py index c12a6c90..b33ae84d 100644 --- a/effectful/internals/disassembler.py +++ b/effectful/internals/disassembler.py @@ -1833,3 +1833,153 @@ def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression: """ assert inspect.isgenerator(genexpr), "Input must be a generator expression" return ast.fix_missing_locations(ast.Expression(ensure_ast(genexpr))) + + +class NameToCall(ast.NodeTransformer): + """ + Transform variable names into calls to those variables. + This transformer replaces occurrences of specified variable names in an AST + with calls to those variables. For example, if the variable name is 'x', + it will replace 'x' with 'x()'. + """ + + varnames: set[str] + + def __init__(self, varnames: set[str]): + self.varnames = varnames + + def visit_Name(self, node: ast.Name) -> ast.Call | ast.Name: + if node.id in self.varnames and isinstance(node.ctx, ast.Load): + return ast.Call(node, args=[], keywords=[]) + elif node.id in self.varnames and isinstance(node.ctx, ast.Store): + # If it's a store context, we don't want to transform it into a call + return ast.Name(id=node.id, ctx=ast.Load()) + else: + return node + + +class NameStoreToLoad(ast.NodeTransformer): + """ + Transform variable names in store context to load context. + This transformer changes all variable names in store context (e.g., assignments) + to load context, effectively treating them as read-only variables. + """ + + def visit_Name(self, node: ast.Name) -> ast.Name: + return ast.Name(id=node.id, ctx=ast.Load()) + + def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple: + return ast.Tuple(elts=[self.visit(elt) for elt in node.elts], ctx=ast.Load()) + + +class GeneratorExpToDefstream(ast.NodeTransformer): + """ + Transform generator expressions into calls to `defstream`. + This transformer converts generator expressions of the form: + + (expr for var in iter) + into calls to `defstream`: + defstream(expr, {var: lambda: iter}) + + It supports: + - Multiple nested loops + - Filter conditions (if clauses) - converted to filtered generator expressions + - Tuple unpacking in loop variables + - Variables are correctly transformed into calls within the expression and iterators + + Examples: + >>> import ast + >>> source = "(x * 2 for x in range(10))" + >>> tree = ast.parse(source, mode='eval') + >>> transformer = GeneratorExpTodefstream() + >>> transformed = transformer.visit(tree) + >>> ast.unparse(transformed) + 'defstream(x() * 2, {x: lambda: range(10)})' + + >>> source = "(x for x in range(10) if x % 2 == 0)" + >>> tree = ast.parse(source, mode='eval') + >>> transformed = transformer.visit(tree) + >>> ast.unparse(transformed) + 'defstream(x(), {x: (x for x in range(10) if x % 2 == 0)})' + + >>> source = "((x, y) for x, y in pairs)" + >>> tree = ast.parse(source, mode='eval') + >>> transformed = transformer.visit(tree) + >>> ast.unparse(transformed) + 'defstream((x(), y()), {(x, y): lambda: pairs})' + + """ + + def visit_GeneratorExp(self, node: ast.GeneratorExp) -> ast.Call: + # Check for unsupported features + for gen in node.generators: + if not isinstance(gen.target, ast.Name) and not isinstance( + gen.target, ast.Tuple + ): + raise NotImplementedError( + f"Unsupported target type: {type(gen.target)}" + ) + + # Get all variable names from all targets (including unpacked tuples) + def get_names_from_target(target): + if isinstance(target, ast.Name): + return [target.id] + elif isinstance(target, ast.Tuple): + names = [] + for elt in target.elts: + names.extend(get_names_from_target(elt)) + return names + else: + raise NotImplementedError( + f"Unsupported target type in unpacking: {type(target)}" + ) + + streams = ast.Dict(keys=[], values=[]) + all_var_names: set[str] = set() + + for gen in node.generators: + # Collect variable names from previous generators + prev_var_names = set(all_var_names) + + # Add current target variables to the set + target_names = get_names_from_target(gen.target) + all_var_names.update(target_names) + + # Create the value for this generator + value: ast.expr # TODO : Specify type more precisely + if gen.ifs: + # If there are filters, create a generator expression for the filtered iterator + # Note: In the filter conditions, we need to transform previous loop variables + # but NOT the current loop variable + filtered_gen = ast.GeneratorExp( + elt=gen.target if isinstance(gen.target, ast.Name) else gen.target, + generators=[ + ast.comprehension( + target=gen.target, + iter=self.visit(NameToCall(prev_var_names).visit(gen.iter)), + ifs=[ + self.visit(NameToCall(prev_var_names).visit(if_clause)) + for if_clause in gen.ifs + ], + is_async=gen.is_async, + ) + ], + ) + value = filtered_gen + else: + # No filters, create an expression + value = self.visit(NameToCall(prev_var_names).visit(gen.iter)) + + streams.keys.append(NameStoreToLoad().visit(gen.target)) + streams.values.append(value) + + # Transform the body expression + # First apply NameToCall, then recursively visit for nested generators + body: ast.expr = NameToCall(all_var_names).visit(node.elt) + body = self.visit(body) # Recursively transform nested generator expressions + + return ast.Call( + func=ast.Name(id="defstream", ctx=ast.Load()), + args=[body, streams], + keywords=[], + ) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 99d6edaf..5b86e6a8 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1,3 +1,4 @@ +import ast import collections.abc import dataclasses import functools @@ -1141,6 +1142,25 @@ def __next__(self: collections.abc.Iterator[T]) -> T: next_ = _IteratorTerm.__next__ +@defterm.register(types.GeneratorType) +def _[T](genexpr: types.GeneratorType[T, None, None]) -> Expr[Iterable[T]]: + from effectful.internals.disassembler import GeneratorExpToDefstream, reconstruct + + genexpr_ast = reconstruct(genexpr) + forexpr_ast = GeneratorExpToDefstream().visit(genexpr_ast) + forexpr_name = ".".join(genexpr.gi_code.co_name.split(".")[:-1] + [""]) + forexpr_code = compile( + ast.fix_missing_locations(forexpr_ast), + filename=forexpr_name, + mode="eval", + ) + return eval( + forexpr_code, + genexpr.gi_frame.f_globals | {"defstream": defstream}, + genexpr.gi_frame.f_locals, # TODO infer types and construct stream variables + ) + + def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool: """Syntactic equality, ignoring the interpretation of the terms. diff --git a/tests/test_internals_disassembler.py b/tests/test_internals_disassembler.py index facab839..5da3e998 100644 --- a/tests/test_internals_disassembler.py +++ b/tests/test_internals_disassembler.py @@ -1,16 +1,20 @@ import ast -from types import GeneratorType -from typing import Any +from collections.abc import Generator import pytest import tree -from effectful.internals.disassembler import reconstruct +from effectful.internals.disassembler import ( + GeneratorExpToDefstream, + NameToCall, + ensure_ast, + reconstruct, +) def compile_and_eval( node: ast.expr | ast.Expression, globals_dict: dict | None = None -) -> Any: +) -> object: """Compile an AST node and evaluate it.""" if globals_dict is None: globals_dict = {} @@ -27,11 +31,11 @@ def compile_and_eval( return eval(code, globals_dict) -def materialize(genexpr: GeneratorType) -> tree.Structure: +def materialize(genexpr: Generator[object, None, None]) -> tree.Structure: """Materialize a nested generator expression to a nested list.""" def _materialize(genexpr): - if isinstance(genexpr, GeneratorType): + if isinstance(genexpr, Generator): return tree.map_structure(_materialize, list(genexpr)) elif tree.is_nested(genexpr): return tree.map_structure(_materialize, genexpr) @@ -42,7 +46,9 @@ def _materialize(genexpr): def assert_ast_equivalent( - genexpr: GeneratorType, reconstructed_ast: ast.AST, globals_dict: dict | None = None + genexpr: Generator[object, None, None], + reconstructed_ast: ast.AST, + globals_dict: dict | None = None, ): """Assert that a reconstructed AST produces the same results as the original generator.""" # Check AST structure @@ -70,6 +76,7 @@ def assert_ast_equivalent( # Compile and evaluate the reconstructed AST reconstructed_gen = compile_and_eval(reconstructed_ast, globals_dict) + assert isinstance(reconstructed_gen, Generator) reconstructed_list = materialize(reconstructed_gen) assert reconstructed_list == original_list, ( f"AST produced {reconstructed_list}, expected {original_list}" @@ -647,8 +654,6 @@ def test_complex_scenarios(genexpr, globals_dict): ) def test_ensure_ast(value, expected_str): """Test that ensure_ast correctly converts various values to AST nodes.""" - from effectful.internals.disassembler import ensure_ast - result = ensure_ast(value) # Compare the unparsed strings @@ -717,3 +722,496 @@ def test_comp_lambda_copy(): assert ast.unparse(original_result) == ast.unparse(deep_copied_result) assert type(original_result) == type(deep_copied_result) + + +# ============================================================================ +# AST TRANSFORMER TESTS +# ============================================================================ + + +@pytest.mark.parametrize( + "source_src,varnames,expected_src", + [ + # Simple name replacement + ("x", {"x"}, "x()"), + ("x + y", {"x"}, "x() + y"), + ("x + y", {"y"}, "x + y()"), + ("x + y", {"x", "y"}, "x() + y()"), + # Names in different contexts (Note: Store/Del contexts work differently) + # These would need to be parsed as statements, not expressions + # Skipping assignment and del tests as they can't be parsed in eval mode + # Complex expressions + ("x * 2 + y / 3", {"x", "y"}, "x() * 2 + y() / 3"), + ("func(x, y, z)", {"x", "z"}, "func(x(), y, z())"), + # Nested expressions + ("x + (y * z)", {"x", "y", "z"}, "x() + (y() * z())"), + ("[x, y, z]", {"x", "y"}, "[x(), y(), z]"), + ("{'a': x, 'b': y}", {"x", "y"}, "{'a': x(), 'b': y()}"), + # Attribute access + ("x.attr", {"x"}, "x().attr"), + ("x.method()", {"x"}, "x().method()"), + ("obj.x", {"x"}, "obj.x"), # x is an attribute, not a variable + # Subscript operations + ("x[0]", {"x"}, "x()[0]"), + ("arr[x]", {"x"}, "arr[x()]"), + ("x[y]", {"x", "y"}, "x()[y()]"), + # Function calls + ("f(x)", {"f"}, "f()(x)"), + ("f(x)", {"x"}, "f(x())"), + ("f(x, y)", {"f", "x", "y"}, "f()(x(), y())"), + # Lambda expressions + ("lambda a: x + a", {"x"}, "lambda a: x() + a"), + ( + "lambda x: x + y", + {"x", "y"}, + "lambda x: x() + y()", + ), # Transformer doesn't check scope + # Comprehensions + ("[x for i in range(3)]", {"x"}, "[x() for i in range(3)]"), + ("[i for i in x]", {"x"}, "[i for i in x()]"), + ( + "[x for x in range(3)]", + {"x"}, + "[x() for x in range(3)]", + ), # Transformer doesn't check scope + ("(x + i for i in range(3))", {"x"}, "(x() + i for i in range(3))"), + # Multiple occurrences + ("x + x * x", {"x"}, "x() + x() * x()"), + # Names not in the set should not be transformed + ("x + y + z", {"x"}, "x() + y + z"), + ("a + b + c", {"x", "y", "z"}, "a + b + c"), + ], +) +def test_name_to_call_transformer(source_src, varnames, expected_src): + """Test NameToCall transformer converts specified names to function calls.""" + # Parse the source code + source_ast = ast.parse(source_src, mode="eval") + expected_ast = ast.parse(expected_src, mode="eval") + + # Apply the transformer + transformer = NameToCall(varnames) + transformed_ast = transformer.visit(source_ast) + + # Compare the unparsed strings + transformed_src = ast.unparse(transformed_ast) + expected_unparsed = ast.unparse(expected_ast) + + assert transformed_src == expected_unparsed, ( + f"Transformation failed:\n" + f"Input: {source_src}\n" + f"Varnames: {varnames}\n" + f"Expected: {expected_unparsed}\n" + f"Got: {transformed_src}" + ) + + +@pytest.mark.parametrize( + "source_src,varnames,expected_src", + [ + # Nested comprehensions in body + ("[x for x in range(i)]", {"i"}, "[x for x in range(i())]"), + ( + "[[x for x in range(j)] for j in range(i)]", + {"i"}, + "[[x for x in range(j)] for j in range(i())]", + ), + ( + "[[x for x in range(j)] for j in range(i)]", + {"i", "j"}, + "[[x for x in range(j())] for j in range(i())]", + ), + # Complex nested comprehensions + ("sum(x for x in range(i))", {"i"}, "sum((x for x in range(i())))"), + ("sum(x for x in range(i))", {"x"}, "sum((x() for x in range(i)))"), + ("sum(x for x in range(i))", {"i", "x"}, "sum((x() for x in range(i())))"), + # Names in nested comprehensions + ("[y for y in range(x)]", {"x"}, "[y for y in range(x())]"), + ("[y for y in range(x)]", {"y"}, "[y() for y in range(x)]"), + ( + "[[y + z for y in range(x)] for z in range(x)]", + {"x"}, + "[[y + z for y in range(x())] for z in range(x())]", + ), + ( + "[[y + z for y in range(x)] for z in range(x)]", + {"x", "y", "z"}, + "[[y() + z() for y in range(x())] for z in range(x())]", + ), + # Mixed comprehension types with names + ( + "{x: [y for y in range(x)] for x in range(n)}", + {"n"}, + "{x: [y for y in range(x)] for x in range(n())}", + ), + ( + "{i: [j for j in range(i)] for i in range(n)}", + {"i", "j", "n"}, + "{i(): [j() for j in range(i())] for i in range(n())}", + ), + # Names in different parts of comprehensions + ("[f(x) for x in data]", {"f"}, "[f()(x) for x in data]"), + ("[f(x) for x in data]", {"data"}, "[f(x) for x in data()]"), + ("[f(x) for x in data]", {"f", "data"}, "[f()(x) for x in data()]"), + ( + "[f(x) for x in data if pred(x)]", + {"f", "pred", "data"}, + "[f()(x) for x in data() if pred()(x)]", + ), + # Nested function calls + ("f(g(x))", {"f"}, "f()(g(x))"), + ("f(g(x))", {"g"}, "f(g()(x))"), + ("f(g(x))", {"x"}, "f(g(x()))"), + ("f(g(h(x)))", {"f", "g", "h", "x"}, "f()(g()(h()(x())))"), + # Complex expressions with comprehensions + ( + "sum([x * y for x in range(a) for y in range(b)])", + {"a", "b"}, + "sum([x * y for x in range(a()) for y in range(b())])", + ), + ( + "max(x + y for x in items1 for y in items2)", + {"items1", "items2"}, + "max((x + y for x in items1() for y in items2()))", + ), + # Boolean operations with names + ("x and y", {"x"}, "x() and y"), + ("x and y", {"x", "y"}, "x() and y()"), + ("x or y or z", {"x", "z"}, "x() or y or z()"), + ("not x", {"x"}, "not x()"), + # Ternary expressions + ("x if cond else y", {"cond"}, "x if cond() else y"), + ("x if cond else y", {"x", "y", "cond"}, "x() if cond() else y()"), + # Names in slice operations + ("arr[start:end]", {"start", "end"}, "arr[start():end()]"), + ("arr[i:j:k]", {"i", "j", "k"}, "arr[i():j():k()]"), + ("matrix[i][j]", {"i", "j"}, "matrix[i()][j()]"), + # Dict/set comprehensions + ("{x: y for x, y in pairs}", {"pairs"}, "{x: y for (x, y) in pairs()}"), + ("{f(x) for x in items}", {"f", "items"}, "{f()(x) for x in items()}"), + # Nested comprehensions with filters + ( + "[x for x in [y for y in range(n) if y > m] if x < k]", + {"n", "m", "k"}, + "[x for x in [y for y in range(n()) if y > m()] if x < k()]", + ), + ], +) +def test_name_to_call_nested_comprehensions(source_src, varnames, expected_src): + """Test NameToCall transformer with nested comprehensions and complex expressions.""" + # Parse the source code + source_ast = ast.parse(source_src, mode="eval") + expected_ast = ast.parse(expected_src, mode="eval") + + # Apply the transformer + transformer = NameToCall(varnames) + transformed_ast = transformer.visit(source_ast) + + # Compare the unparsed strings + transformed_src = ast.unparse(transformed_ast) + expected_unparsed = ast.unparse(expected_ast) + + assert transformed_src == expected_unparsed, ( + f"Transformation failed:\n" + f"Input: {source_src}\n" + f"Varnames: {varnames}\n" + f"Expected: {expected_unparsed}\n" + f"Got: {transformed_src}" + ) + + +@pytest.mark.parametrize( + "genexpr_src,expected_src", + [ + # Simple generator expressions + ( + "(x for x in range(10))", + "defstream(x(), {x: range(10)})", + ), + ( + "(x * 2 for x in range(10))", + "defstream(x() * 2, {x: range(10)})", + ), + ( + "(x + 1 for x in items)", + "defstream(x() + 1, {x: items})", + ), + # Complex expressions + ( + "(x ** 2 + 2 * x + 1 for x in range(5))", + "defstream(x() ** 2 + 2 * x() + 1, {x: range(5)})", + ), + ( + "(f(x) for x in data)", + "defstream(f(x()), {x: data})", + ), + # Multiple nested loops + ( + "(x + y for x in range(3) for y in range(4))", + "defstream(x() + y(), {x: range(3), y: range(4)})", + ), + ( + "(x * y for x in items1 for y in items2)", + "defstream(x() * y(), {x: items1, y: items2})", + ), + # Nested loops with dependencies + ( + "(x + y for x in range(3) for y in range(x))", + "defstream(x() + y(), {x: range(3), y: range(x())})", + ), + ( + "((x, y) for x in range(3) for y in range(x, 5))", + "defstream((x(), y()), {x: range(3), y: range(x(), 5)})", + ), + # Triple nested loops + ( + "(x + y + z for x in range(2) for y in range(2) for z in range(2))", + "defstream(x() + y() + z(), {x: range(2), y: range(2), z: range(2)})", + ), + # Complex iterators + ( + "(x for x in [1, 2, 3])", + "defstream(x(), {x: [1, 2, 3]})", + ), + ( + "(x for x in list(range(5)))", + "defstream(x(), {x: list(range(5))})", + ), + # Expressions with function calls on iterators + ( + "(x for x in sorted(items))", + "defstream(x(), {x: sorted(items)})", + ), + # Generator expressions with filters + ( + "(x for x in range(10) if x % 2 == 0)", + "defstream(x(), {x: (x for x in range(10) if x % 2 == 0)})", + ), + ( + "(x + y for x in range(3) if x > 0 for y in range(3))", + "defstream(x() + y(), {x: (x for x in range(3) if x > 0), y: range(3)})", + ), + ( + "(x * 2 for x in items if x > 5)", + "defstream(x() * 2, {x: (x for x in items if x > 5)})", + ), + ( + "(x + y for x in range(5) for y in range(5) if x < y)", + "defstream(x() + y(), {x: range(5), y: (y for y in range(5) if x() < y)})", + ), + ( + "(x for x in range(20) if x % 2 == 0 if x % 3 == 0)", + "defstream(x(), {x: (x for x in range(20) if x % 2 == 0 if x % 3 == 0)})", + ), + # Generator expressions with unpacking + ( + "((x, y) for x, y in pairs)", + "defstream((x(), y()), {(x, y): pairs})", + ), + ( + "(a + b for a, b in zip(list1, list2))", + "defstream(a() + b(), {(a, b): zip(list1, list2)})", + ), + ( + "(x + y + z for x, (y, z) in nested_pairs)", + "defstream(x() + y() + z(), {(x, (y, z)): nested_pairs})", + ), + ], +) +def test_generator_exp_to_defstream_transformer(genexpr_src, expected_src): + """Test GeneratorExpToDefstream transformer converts generator expressions to defstream calls.""" + # Parse the source code + source_ast = ast.parse(genexpr_src, mode="eval") + + # Apply the transformer + transformer = GeneratorExpToDefstream() + + # For test cases that should raise NotImplementedError + if expected_src == "NOT_APPLICABLE": + # The xfail marker will handle the exception + transformed_ast = transformer.visit(source_ast) + return + + transformed_ast = transformer.visit(source_ast) + expected_ast = ast.parse(expected_src, mode="eval") + + # Compare the unparsed strings + transformed_src = ast.unparse(transformed_ast) + expected_unparsed = ast.unparse(expected_ast) + + assert transformed_src == expected_unparsed, ( + f"Transformation failed:\n" + f"Input: {genexpr_src}\n" + f"Expected: {expected_unparsed}\n" + f"Got: {transformed_src}" + ) + + +@pytest.mark.parametrize( + "genexpr_src,expected_src", + [ + # Generator expressions yielding comprehensions + ( + "([x for x in range(i)] for i in range(5))", + "defstream([x for x in range(i())], {i: range(5)})", + ), + ( + "({x: x**2 for x in range(i)} for i in range(5))", + "defstream({x: x**2 for x in range(i())}, {i: range(5)})", + ), + ( + "({x for x in range(i) if x > 2} for i in range(10))", + "defstream({x for x in range(i()) if x > 2}, {i: range(10)})", + ), + # Double nested comprehensions + ( + "([[x for x in range(j)] for j in range(i)] for i in range(3))", + "defstream([[x for x in range(j)] for j in range(i())], {i: range(3)})", + ), + ( + "([[x + y for y in range(j)] for j in range(i)] for i in range(3))", + "defstream([[x + y for y in range(j)] for j in range(i())], {i: range(3)})", + ), + # Generator with comprehension in iterator + ( + "(x * 2 for x in [y**2 for y in range(5)])", + "defstream(x() * 2, {x: [y**2 for y in range(5)]})", + ), + ( + "(x for x in {y: y**2 for y in range(3)}.values())", + "defstream(x(), {x: {y: y**2 for y in range(3)}.values()})", + ), + # Complex expressions with nested calls + ( + "(sum([x for x in range(i)]) for i in range(5))", + "defstream(sum([x for x in range(i())]), {i: range(5)})", + ), + ( + "(max(x for x in range(i + 1)) for i in range(3))", + "defstream(max(defstream(x(), {x: range(i() + 1)})), {i: range(3)})", + ), + ( + "(list(enumerate(x + 1 for x in range(i + 1))) for i in range(3))", + "defstream(list(enumerate(defstream(x() + 1, {x: range(i() + 1)}))), {i: range(3)})", + ), + # Nested generators in iterator + ( + "([y for y in range(x)] for x in (z + 1 for z in range(5)))", + "defstream([y for y in range(x())], {x: defstream(z() + 1, {z: range(5)})})", + ), + # Complex filters with nested comprehensions + ( + "(x for x in range(10) if x in [y**2 for y in range(5)])", + "defstream(x(), {x: (x for x in range(10) if x in [y**2 for y in range(5)])})", + ), + ( + "([x, y] for x in range(3) if x > 0 for y in range(3) if y in [z for z in range(x)])", + "defstream([x(), y()], {x: (x for x in range(3) if x > 0), y: (y for y in range(3) if y in [z for z in range(x())])})", + ), + # Multiple filters with dependencies + ( + "(x + y for x in range(5) if x > 1 for y in range(x) if y < x - 1)", + "defstream(x() + y(), {x: (x for x in range(5) if x > 1), y: (y for y in range(x()) if y < x() - 1)})", + ), + ( + "((x, y, z) for x in range(3) for y in range(x, 5) if y > x for z in range(y) if z < y)", + "defstream((x(), y(), z()), {x: range(3), y: (y for y in range(x(), 5) if y > x()), z: (z for z in range(y()) if z < y())})", + ), + # Unpacking with nested structures + ( + "(a + b + c for (a, b), c in [((1, 2), 3), ((4, 5), 6)])", + "defstream(a() + b() + c(), {((a, b), c): [((1, 2), 3), ((4, 5), 6)]})", + ), + ( + "(x + sum(lst) for x, lst in [(1, [2, 3]), (4, [5, 6])])", + "defstream(x() + sum(lst()), {(x, lst): [(1, [2, 3]), (4, [5, 6])]})", + ), + # Complex iterators + ( + "(x for x in sorted([y**2 for y in range(5)]))", + "defstream(x(), {x: sorted([y**2 for y in range(5)])})", + ), + ( + "(item for sublist in [[1, 2], [3, 4], [5, 6]] for item in sublist)", + "defstream(item(), {sublist: [[1, 2], [3, 4], [5, 6]], item: sublist()})", + ), + # Expressions with method calls + ( + "(s.upper() for s in ['hello', 'world'] if s.startswith('h'))", + "defstream(s().upper(), {s: (s for s in ['hello', 'world'] if s.startswith('h'))})", + ), + ( + "(obj.value for obj in objects if hasattr(obj, 'value'))", + "defstream(obj().value, {obj: (obj for obj in objects if hasattr(obj, 'value'))})", + ), + # CRITICAL: Generator expressions yielding generator expressions + ( + "((x for x in range(i)) for i in range(5))", + "defstream(defstream(x(), {x: range(i())}), {i: range(5)})", + ), + ( + "((x * 2 for x in range(i)) for i in range(3))", + "defstream(defstream(x() * 2, {x: range(i())}), {i: range(3)})", + ), + ( + "((x + y for x in range(3) for y in range(x)) for i in range(2))", + "defstream(defstream(x() + y(), {x: range(3), y: range(x())}), {i: range(2)})", + ), + # Generator yielding filtered generator + ( + "((x for x in range(10) if x % 2 == 0) for i in range(3))", + "defstream(defstream(x(), {x: (x for x in range(10) if x % 2 == 0)}), {i: range(3)})", + ), + ( + "((x for x in range(i) if x > 0) for i in range(5))", + "defstream(defstream(x(), {x: (x for x in range(i()) if x > 0)}), {i: range(5)})", + ), + # Nested generators with multiple levels + ( + "((y for y in (x for x in range(i))) for i in range(3))", + "defstream(defstream(y(), {y: defstream(x(), {x: range(i())})}), {i: range(3)})", + ), + ( + "(((x + y for x in range(2)) for y in range(3)) for z in range(4))", + "defstream(defstream(defstream(x() + y(), {x: range(2)}), {y: range(3)}), {z: range(4)})", + ), + # Generator with unpacking yielding generator + ( + "((x + b for x in range(a)) for a, b in [(2, 3), (4, 5)])", + "defstream(defstream(x() + b(), {x: range(a())}), {(a, b): [(2, 3), (4, 5)]})", + ), + # Complex case: generator yielding generator with filters and dependencies + ( + "((x + y for x in range(i) if x > 0 for y in range(x)) for i in range(5) if i > 2)", + "defstream(defstream(x() + y(), {x: (x for x in range(i()) if x > 0), y: range(x())}), {i: (i for i in range(5) if i > 2)})", + ), + # Generator expression yielding sum of generator expression + ( + "(sum(x for x in range(i)) for i in range(5))", + "defstream(sum(defstream(x(), {x: range(i())})), {i: range(5)})", + ), + ( + "(max(x * 2 for x in range(i) if x > 0) for i in range(10))", + "defstream(max(defstream(x() * 2, {x: (x for x in range(i()) if x > 0)})), {i: range(10)})", + ), + ], +) +def test_generator_exp_to_defstream_nested_comprehensions(genexpr_src, expected_src): + """Test GeneratorExpToDefstream transformer with nested comprehensions and complex expressions.""" + # Parse the source code + source_ast = ast.parse(genexpr_src, mode="eval") + + # Apply the transformer + transformer = GeneratorExpToDefstream() + transformed_ast = transformer.visit(source_ast) + expected_ast = ast.parse(expected_src, mode="eval") + + # Compare the unparsed strings + transformed_src = ast.unparse(transformed_ast) + expected_unparsed = ast.unparse(expected_ast) + + assert transformed_src == expected_unparsed, ( + f"Transformation failed:\n" + f"Input: {genexpr_src}\n" + f"Expected: {expected_unparsed}\n" + f"Got: {transformed_src}" + ) diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index d4e06de1..bef1bae7 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -549,3 +549,11 @@ def test_defstream_1(): # assert isinstance(tm_iter_next, numbers.Number) # TODO # assert issubclass(typeof(tm_iter_next), numbers.Number) assert tm_iter_next.op is next_ + + +def test_defterm_genexpr(): + xs = (x + 1 for x in range(5)) + + tm = defterm(xs) + assert isinstance(tm, Term) + assert tm.op is defstream