Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 30 additions & 12 deletions effectful/ops/semantics.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import collections.abc
import contextlib
import functools
import types
import typing
from collections.abc import Callable
from typing import Any

import tree

from effectful.ops.syntax import deffn, defop
from effectful.ops.types import Expr, Interpretation, Operation, Term

Expand Down Expand Up @@ -68,7 +67,7 @@ def call[**P, T](fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
}
with handler(subs):
return evaluate(body)
elif not any(isinstance(a, Term) for a in tree.flatten((fn, args, kwargs))):
elif not fvsof((fn, args, kwargs)):
return fn(*args, **kwargs)
else:
raise NotImplementedError
Expand Down Expand Up @@ -246,18 +245,37 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]
6

"""
if intp is None:
from effectful.internals.runtime import get_interpretation
from effectful.internals.runtime import get_interpretation, interpreter

intp = get_interpretation()
if intp is not None:
return interpreter(intp)(evaluate)(expr)

if isinstance(expr, Term):
(args, kwargs) = tree.map_structure(
functools.partial(evaluate, intp=intp), (expr.args, expr.kwargs)
)
return apply.__default_rule__(intp, expr.op, *args, **kwargs)
elif tree.is_nested(expr):
return tree.map_structure(functools.partial(evaluate, intp=intp), expr)
args = tuple(evaluate(arg) for arg in expr.args)
kwargs = {k: evaluate(v) for k, v in expr.kwargs.items()}
return expr.op(*args, **kwargs)
elif isinstance(expr, Operation):
op_intp = get_interpretation().get(expr, expr)
return op_intp if isinstance(op_intp, Operation) else expr # type: ignore
elif isinstance(expr, collections.abc.Mapping):
if isinstance(expr, collections.defaultdict):
return type(expr)(expr.default_factory, evaluate(tuple(expr.items()))) # type: ignore
elif isinstance(expr, types.MappingProxyType):
return type(expr)(dict(evaluate(tuple(expr.items())))) # type: ignore
else:
return type(expr)(evaluate(tuple(expr.items()))) # type: ignore
elif isinstance(expr, collections.abc.Sequence):
if isinstance(expr, str | bytes):
return typing.cast(T, expr) # mypy doesnt like ignore here, so we use cast
else:
return type(expr)(evaluate(item) for item in expr) # type: ignore
elif isinstance(expr, collections.abc.Set):
if isinstance(expr, collections.abc.ItemsView | collections.abc.KeysView):
return {evaluate(item) for item in expr} # type: ignore
else:
return type(expr)(evaluate(item) for item in expr) # type: ignore
elif isinstance(expr, collections.abc.ValuesView):
return [evaluate(item) for item in expr] # type: ignore
else:
return expr

Expand Down
54 changes: 23 additions & 31 deletions effectful/ops/syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from collections.abc import Callable, Iterable, Mapping
from typing import Annotated, Concatenate

import tree

from effectful.ops.types import Annotation, Expr, Operation, Term


Expand Down Expand Up @@ -355,7 +353,7 @@ def analyze(self, bound_sig: inspect.BoundArguments) -> frozenset[Operation]:
else:
param_bound_vars = {param_value}
elif param_ordinal: # Only process if there's a Scoped annotation
# We can't use tree.flatten here because we want to be able
# We can't use flatten here because we want to be able
# to see dict keys
def extract_operations(obj):
if isinstance(obj, Operation):
Expand Down Expand Up @@ -662,7 +660,9 @@ def func() -> t: # type: ignore
def _[**P, T](t: Callable[P, T], *, name: str | None = None) -> Operation[P, T]:
@functools.wraps(t)
def func(*args, **kwargs):
if not any(isinstance(a, Term) for a in tree.flatten((args, kwargs))):
from effectful.ops.semantics import fvsof

if not fvsof((args, kwargs)):
return t(*args, **kwargs)
else:
raise NotImplementedError
Expand Down Expand Up @@ -872,18 +872,6 @@ def defterm[T](__dispatch: Callable[[type], Callable[[T], Expr[T]]], value: T):
return __dispatch(type(value))(value)


def _map_structure_and_keys(func, structure):
def _map_value(value):
if isinstance(value, dict):
return {func(k): v for k, v in value.items()}
elif not tree.is_nested(value):
return func(value)
else:
return value

return tree.traverse(_map_value, structure, top_down=False)


@_CustomSingleDispatchCallable
def defdata[T](
__dispatch: Callable[[type], Callable[..., Expr[T]]],
Expand Down Expand Up @@ -960,9 +948,6 @@ def _(op, *args, **kwargs):
*{k: (v, kwarg_ctxs[k]) for k, v in kwargs.items()}.items(),
):
if c:
v = _map_structure_and_keys(
lambda a: renaming.get(a, a) if isinstance(a, Operation) else a, v
)
res = evaluate(
v,
intp={
Expand Down Expand Up @@ -1133,21 +1118,28 @@ def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool:
if isinstance(x, Term) and isinstance(other, Term):
op, args, kwargs = x.op, x.args, x.kwargs
op2, args2, kwargs2 = other.op, other.args, other.kwargs
try:
tree.assert_same_structure(
(op, args, kwargs), (op2, args2, kwargs2), check_types=True
)
except (TypeError, ValueError):
return False
return all(
tree.flatten(
tree.map_structure(
syntactic_eq, (op, args, kwargs), (op2, args2, kwargs2)
)
)
return (
op == op2
and len(args) == len(args2)
and set(kwargs) == set(kwargs2)
and all(syntactic_eq(a, b) for a, b in zip(args, args2))
and all(syntactic_eq(kwargs[k], kwargs2[k]) for k in kwargs)
)
elif isinstance(x, Term) or isinstance(other, Term):
return False
elif isinstance(x, collections.abc.Mapping) and isinstance(
other, collections.abc.Mapping
):
return all(
k in x and k in other and syntactic_eq(x[k], other[k])
for k in set(x) | set(other)
)
elif isinstance(x, collections.abc.Sequence) and isinstance(
other, collections.abc.Sequence
):
return len(x) == len(other) and all(
syntactic_eq(a, b) for a, b in zip(x, other)
)
else:
return x == other

Expand Down
8 changes: 0 additions & 8 deletions tests/test_ops_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from effectful.ops.syntax import (
Scoped,
_CustomSingleDispatchCallable,
_map_structure_and_keys,
deffn,
defop,
defstream,
Expand Down Expand Up @@ -111,13 +110,6 @@ def f(x):
assert f_op != ff_op


def test_map_structure_and_keys():
s = {1: 2, 3: [4, 5, (6, {7: 8})]}
expected = {2: 3, 4: [5, 6, (7, {8: 9})]}
actual = _map_structure_and_keys(lambda x: x + 1, s)
assert actual == expected


def test_scoped_collections():
"""Test that Scoped annotations work with tree-structured collections containing Operations."""

Expand Down