Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions effectful/internals/disassembler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1833,3 +1833,153 @@ def reconstruct(genexpr: Generator[object, None, None]) -> ast.Expression:
"""
assert inspect.isgenerator(genexpr), "Input must be a generator expression"
return ast.fix_missing_locations(ast.Expression(ensure_ast(genexpr)))


class NameToCall(ast.NodeTransformer):
"""
Transform variable names into calls to those variables.
This transformer replaces occurrences of specified variable names in an AST
with calls to those variables. For example, if the variable name is 'x',
it will replace 'x' with 'x()'.
"""

varnames: set[str]

def __init__(self, varnames: set[str]):
self.varnames = varnames

def visit_Name(self, node: ast.Name) -> ast.Call | ast.Name:
if node.id in self.varnames and isinstance(node.ctx, ast.Load):
return ast.Call(node, args=[], keywords=[])
elif node.id in self.varnames and isinstance(node.ctx, ast.Store):
# If it's a store context, we don't want to transform it into a call
return ast.Name(id=node.id, ctx=ast.Load())
else:
return node


class NameStoreToLoad(ast.NodeTransformer):
"""
Transform variable names in store context to load context.
This transformer changes all variable names in store context (e.g., assignments)
to load context, effectively treating them as read-only variables.
"""

def visit_Name(self, node: ast.Name) -> ast.Name:
return ast.Name(id=node.id, ctx=ast.Load())

def visit_Tuple(self, node: ast.Tuple) -> ast.Tuple:
return ast.Tuple(elts=[self.visit(elt) for elt in node.elts], ctx=ast.Load())


class GeneratorExpToDefstream(ast.NodeTransformer):
"""
Transform generator expressions into calls to `defstream`.
This transformer converts generator expressions of the form:

(expr for var in iter)
into calls to `defstream`:
defstream(expr, {var: lambda: iter})

It supports:
- Multiple nested loops
- Filter conditions (if clauses) - converted to filtered generator expressions
- Tuple unpacking in loop variables
- Variables are correctly transformed into calls within the expression and iterators

Examples:
>>> import ast
>>> source = "(x * 2 for x in range(10))"
>>> tree = ast.parse(source, mode='eval')
>>> transformer = GeneratorExpTodefstream()
>>> transformed = transformer.visit(tree)
>>> ast.unparse(transformed)
'defstream(x() * 2, {x: lambda: range(10)})'

>>> source = "(x for x in range(10) if x % 2 == 0)"
>>> tree = ast.parse(source, mode='eval')
>>> transformed = transformer.visit(tree)
>>> ast.unparse(transformed)
'defstream(x(), {x: (x for x in range(10) if x % 2 == 0)})'

>>> source = "((x, y) for x, y in pairs)"
>>> tree = ast.parse(source, mode='eval')
>>> transformed = transformer.visit(tree)
>>> ast.unparse(transformed)
'defstream((x(), y()), {(x, y): lambda: pairs})'

"""

def visit_GeneratorExp(self, node: ast.GeneratorExp) -> ast.Call:
# Check for unsupported features
for gen in node.generators:
if not isinstance(gen.target, ast.Name) and not isinstance(
gen.target, ast.Tuple
):
raise NotImplementedError(
f"Unsupported target type: {type(gen.target)}"
)

# Get all variable names from all targets (including unpacked tuples)
def get_names_from_target(target):
if isinstance(target, ast.Name):
return [target.id]
elif isinstance(target, ast.Tuple):
names = []
for elt in target.elts:
names.extend(get_names_from_target(elt))
return names
else:
raise NotImplementedError(
f"Unsupported target type in unpacking: {type(target)}"
)

streams = ast.Dict(keys=[], values=[])
all_var_names: set[str] = set()

for gen in node.generators:
# Collect variable names from previous generators
prev_var_names = set(all_var_names)

# Add current target variables to the set
target_names = get_names_from_target(gen.target)
all_var_names.update(target_names)

# Create the value for this generator
value: ast.expr # TODO : Specify type more precisely
if gen.ifs:
# If there are filters, create a generator expression for the filtered iterator
# Note: In the filter conditions, we need to transform previous loop variables
# but NOT the current loop variable
filtered_gen = ast.GeneratorExp(
elt=gen.target if isinstance(gen.target, ast.Name) else gen.target,
generators=[
ast.comprehension(
target=gen.target,
iter=self.visit(NameToCall(prev_var_names).visit(gen.iter)),
ifs=[
self.visit(NameToCall(prev_var_names).visit(if_clause))
for if_clause in gen.ifs
],
is_async=gen.is_async,
)
],
)
value = filtered_gen
else:
# No filters, create an expression
value = self.visit(NameToCall(prev_var_names).visit(gen.iter))

streams.keys.append(NameStoreToLoad().visit(gen.target))
streams.values.append(value)

# Transform the body expression
# First apply NameToCall, then recursively visit for nested generators
body: ast.expr = NameToCall(all_var_names).visit(node.elt)
body = self.visit(body) # Recursively transform nested generator expressions

return ast.Call(
func=ast.Name(id="defstream", ctx=ast.Load()),
args=[body, streams],
keywords=[],
)
20 changes: 20 additions & 0 deletions effectful/ops/syntax.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import collections.abc
import dataclasses
import functools
Expand Down Expand Up @@ -1141,6 +1142,25 @@ def __next__(self: collections.abc.Iterator[T]) -> T:
next_ = _IteratorTerm.__next__


@defterm.register(types.GeneratorType)
def _[T](genexpr: types.GeneratorType[T, None, None]) -> Expr[Iterable[T]]:
from effectful.internals.disassembler import GeneratorExpToDefstream, reconstruct

genexpr_ast = reconstruct(genexpr)
forexpr_ast = GeneratorExpToDefstream().visit(genexpr_ast)
forexpr_name = ".".join(genexpr.gi_code.co_name.split(".")[:-1] + ["<defstream>"])
forexpr_code = compile(
ast.fix_missing_locations(forexpr_ast),
filename=forexpr_name,
mode="eval",
)
return eval(
forexpr_code,
genexpr.gi_frame.f_globals | {"defstream": defstream},
genexpr.gi_frame.f_locals, # TODO infer types and construct stream variables
)


def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool:
"""Syntactic equality, ignoring the interpretation of the terms.

Expand Down
Loading