diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 9fea27d6..16154bb7 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -1,3 +1,4 @@ +import collections.abc import contextlib import functools import types @@ -5,8 +6,6 @@ from collections.abc import Callable from typing import Any -import tree - from effectful.ops.syntax import deffn, defop from effectful.ops.types import Expr, Interpretation, Operation, Term @@ -68,7 +67,7 @@ def call[**P, T](fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: } with handler(subs): return evaluate(body) - elif not any(isinstance(a, Term) for a in tree.flatten((fn, args, kwargs))): + elif not fvsof((fn, args, kwargs)): return fn(*args, **kwargs) else: raise NotImplementedError @@ -246,18 +245,37 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T] 6 """ - if intp is None: - from effectful.internals.runtime import get_interpretation + from effectful.internals.runtime import get_interpretation, interpreter - intp = get_interpretation() + if intp is not None: + return interpreter(intp)(evaluate)(expr) if isinstance(expr, Term): - (args, kwargs) = tree.map_structure( - functools.partial(evaluate, intp=intp), (expr.args, expr.kwargs) - ) - return apply.__default_rule__(intp, expr.op, *args, **kwargs) - elif tree.is_nested(expr): - return tree.map_structure(functools.partial(evaluate, intp=intp), expr) + args = tuple(evaluate(arg) for arg in expr.args) + kwargs = {k: evaluate(v) for k, v in expr.kwargs.items()} + return expr.op(*args, **kwargs) + elif isinstance(expr, Operation): + op_intp = get_interpretation().get(expr, expr) + return op_intp if isinstance(op_intp, Operation) else expr # type: ignore + elif isinstance(expr, collections.abc.Mapping): + if isinstance(expr, collections.defaultdict): + return type(expr)(expr.default_factory, evaluate(tuple(expr.items()))) # type: ignore + elif isinstance(expr, types.MappingProxyType): + return type(expr)(dict(evaluate(tuple(expr.items())))) # type: ignore + else: + return type(expr)(evaluate(tuple(expr.items()))) # type: ignore + elif isinstance(expr, collections.abc.Sequence): + if isinstance(expr, str | bytes): + return typing.cast(T, expr) # mypy doesnt like ignore here, so we use cast + else: + return type(expr)(evaluate(item) for item in expr) # type: ignore + elif isinstance(expr, collections.abc.Set): + if isinstance(expr, collections.abc.ItemsView | collections.abc.KeysView): + return {evaluate(item) for item in expr} # type: ignore + else: + return type(expr)(evaluate(item) for item in expr) # type: ignore + elif isinstance(expr, collections.abc.ValuesView): + return [evaluate(item) for item in expr] # type: ignore else: return expr diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 8a39eb9e..e97f1816 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -8,8 +8,6 @@ from collections.abc import Callable, Iterable, Mapping from typing import Annotated, Concatenate -import tree - from effectful.ops.types import Annotation, Expr, Operation, Term @@ -355,7 +353,7 @@ def analyze(self, bound_sig: inspect.BoundArguments) -> frozenset[Operation]: else: param_bound_vars = {param_value} elif param_ordinal: # Only process if there's a Scoped annotation - # We can't use tree.flatten here because we want to be able + # We can't use flatten here because we want to be able # to see dict keys def extract_operations(obj): if isinstance(obj, Operation): @@ -662,7 +660,9 @@ def func() -> t: # type: ignore def _[**P, T](t: Callable[P, T], *, name: str | None = None) -> Operation[P, T]: @functools.wraps(t) def func(*args, **kwargs): - if not any(isinstance(a, Term) for a in tree.flatten((args, kwargs))): + from effectful.ops.semantics import fvsof + + if not fvsof((args, kwargs)): return t(*args, **kwargs) else: raise NotImplementedError @@ -872,18 +872,6 @@ def defterm[T](__dispatch: Callable[[type], Callable[[T], Expr[T]]], value: T): return __dispatch(type(value))(value) -def _map_structure_and_keys(func, structure): - def _map_value(value): - if isinstance(value, dict): - return {func(k): v for k, v in value.items()} - elif not tree.is_nested(value): - return func(value) - else: - return value - - return tree.traverse(_map_value, structure, top_down=False) - - @_CustomSingleDispatchCallable def defdata[T]( __dispatch: Callable[[type], Callable[..., Expr[T]]], @@ -960,9 +948,6 @@ def _(op, *args, **kwargs): *{k: (v, kwarg_ctxs[k]) for k, v in kwargs.items()}.items(), ): if c: - v = _map_structure_and_keys( - lambda a: renaming.get(a, a) if isinstance(a, Operation) else a, v - ) res = evaluate( v, intp={ @@ -1133,21 +1118,28 @@ def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool: if isinstance(x, Term) and isinstance(other, Term): op, args, kwargs = x.op, x.args, x.kwargs op2, args2, kwargs2 = other.op, other.args, other.kwargs - try: - tree.assert_same_structure( - (op, args, kwargs), (op2, args2, kwargs2), check_types=True - ) - except (TypeError, ValueError): - return False - return all( - tree.flatten( - tree.map_structure( - syntactic_eq, (op, args, kwargs), (op2, args2, kwargs2) - ) - ) + return ( + op == op2 + and len(args) == len(args2) + and set(kwargs) == set(kwargs2) + and all(syntactic_eq(a, b) for a, b in zip(args, args2)) + and all(syntactic_eq(kwargs[k], kwargs2[k]) for k in kwargs) ) elif isinstance(x, Term) or isinstance(other, Term): return False + elif isinstance(x, collections.abc.Mapping) and isinstance( + other, collections.abc.Mapping + ): + return all( + k in x and k in other and syntactic_eq(x[k], other[k]) + for k in set(x) | set(other) + ) + elif isinstance(x, collections.abc.Sequence) and isinstance( + other, collections.abc.Sequence + ): + return len(x) == len(other) and all( + syntactic_eq(a, b) for a, b in zip(x, other) + ) else: return x == other diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index d4e06de1..2d6fb215 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -10,7 +10,6 @@ from effectful.ops.syntax import ( Scoped, _CustomSingleDispatchCallable, - _map_structure_and_keys, deffn, defop, defstream, @@ -111,13 +110,6 @@ def f(x): assert f_op != ff_op -def test_map_structure_and_keys(): - s = {1: 2, 3: [4, 5, (6, {7: 8})]} - expected = {2: 3, 4: [5, 6, (7, {8: 9})]} - actual = _map_structure_and_keys(lambda x: x + 1, s) - assert actual == expected - - def test_scoped_collections(): """Test that Scoped annotations work with tree-structured collections containing Operations."""