From 003cd54908308a8c18ab71418a592ef670b733df Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 1 Jul 2025 16:34:59 -0400 Subject: [PATCH 01/76] stash --- effectful/ops/semantics.py | 37 +++++++++++++++++------------ effectful/ops/syntax.py | 48 +++++++++++++++++++++++++++++++++++--- effectful/ops/types.py | 3 ++- tests/test_ops_syntax.py | 18 ++++++++++++++ 4 files changed, 87 insertions(+), 19 deletions(-) diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 38eb0c86..f6ce276c 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -6,7 +6,7 @@ import tree from typing_extensions import ParamSpec -from effectful.ops.syntax import deffn, defop +from effectful.ops.syntax import deffn, defop, defterm from effectful.ops.types import Expr, Interpretation, Operation, Term P = ParamSpec("P") @@ -73,7 +73,7 @@ def call(fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: } with handler(subs): return evaluate(body) - elif not any(isinstance(a, Term) for a in tree.flatten((fn, args, kwargs))): + elif not any(isinstance(a, Term) for a in (fn, *args, *kwargs.values())): return fn(*args, **kwargs) else: raise NotImplementedError @@ -251,20 +251,27 @@ def evaluate(expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]: 6 """ - if intp is None: - from effectful.internals.runtime import get_interpretation - - intp = get_interpretation() - - if isinstance(expr, Term): - (args, kwargs) = tree.map_structure( - functools.partial(evaluate, intp=intp), (expr.args, expr.kwargs) - ) - return apply.__default_rule__(intp, expr.op, *args, **kwargs) - elif tree.is_nested(expr): - return tree.map_structure(functools.partial(evaluate, intp=intp), expr) + if intp is not None: + from effectful.internals.runtime import interpreter + return interpreter(intp)(evaluate)(expr) + + tm = defterm(expr) + if isinstance(tm, Term): + args = tuple(evaluate(arg) for arg in tm.args) + kwargs = {k: evaluate(v) for k, v in tm.kwargs.items()} + return tm.op(*args, **kwargs) else: - return expr + return tm + + #if isinstance(expr, Term): + # (args, kwargs) = tree.map_structure( + # functools.partial(evaluate, intp=intp), (expr.args, expr.kwargs) + # ) + # return apply.__default_rule__(intp, expr.op, *args, **kwargs) + #elif tree.is_nested(expr): + # return tree.map_structure(functools.partial(evaluate, intp=intp), expr) + #else: + # return expr def typeof(term: Expr[T]) -> type[T]: diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 904e0ba0..339779f6 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -689,7 +689,7 @@ def func() -> t: # type: ignore def _(t: Callable[P, T], *, name: str | None = None) -> Operation[P, T]: @functools.wraps(t) def func(*args, **kwargs): - if not any(isinstance(a, Term) for a in tree.flatten((args, kwargs))): + if not any(isinstance(defterm(a), Term) for a in (*args, *kwargs.values())): return t(*args, **kwargs) else: raise NotImplementedError @@ -1002,6 +1002,8 @@ def _(op, *args, **kwargs): return typed_term +@defterm.register(str) +@defterm.register(bytes) @defterm.register(object) @defterm.register(Operation) @defterm.register(Term) @@ -1118,7 +1120,7 @@ def defstream( @defdata.register(collections.abc.Iterable) -class _IterableTerm(Generic[T], _BaseTerm[collections.abc.Iterable[T]]): +class _IterableTerm(Generic[T], _BaseTerm[collections.abc.Iterable[T]], collections.abc.Iterable[T]): @defop def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: if not isinstance(self, Term): @@ -1128,7 +1130,7 @@ def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: @defdata.register(collections.abc.Iterator) -class _IteratorTerm(Generic[T], _IterableTerm[T]): +class _IteratorTerm(Generic[T], _IterableTerm[T], collections.abc.Iterator[T]): @defop def __next__(self: collections.abc.Iterator[T]) -> T: if not isinstance(self, Term): @@ -1141,6 +1143,46 @@ def __next__(self: collections.abc.Iterator[T]) -> T: next_ = _IteratorTerm.__next__ +@defdata.register(collections.abc.Collection) +class _CollectionTerm(Generic[T], _IterableTerm[T]): + @defop + def __contains__(self: collections.abc.Collection[T], item: T) -> bool: + if not isinstance(self, Term): + return item in self + else: + raise NotImplementedError + + @defop + def __len__(self: collections.abc.Collection[T]) -> int: + if not isinstance(self, Term): + return len(self) + else: + raise NotImplementedError + + +@defterm.register(collections.abc.Sequence) +def _(value: collections.abc.Sequence[Expr[T]]) -> Expr[collections.abc.Sequence[T]]: + tp = type(value) + @defop + def _reconstructor(*items: Expr[T]) -> tp: # type: ignore + if not any(isinstance(e, Term) for e in items): + return tree.sequence._sequence_like(value, items) + else: + raise NotImplementedError + + return _reconstructor(*value) + + +@defdata.register(collections.abc.Sequence) +class _SequenceTerm(Generic[T], _CollectionTerm[T], collections.abc.Sequence[T]): + @defop + def __getitem__(self: collections.abc.Sequence[T], index: int) -> T: + if not isinstance(self, Term): + return self[index] + else: + raise NotImplementedError + + def syntactic_eq(x: Expr[T], other: Expr[T]) -> bool: """Syntactic equality, ignoring the interpretation of the terms. diff --git a/effectful/ops/types.py b/effectful/ops/types.py index ec313290..7302a9f7 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -84,8 +84,9 @@ def __fvs_rule__( def __call__(self, *args: Q.args, **kwargs: Q.kwargs) -> V: from effectful.internals.runtime import get_interpretation from effectful.ops.semantics import apply + from effectful.ops.syntax import defterm - return apply.__default_rule__(get_interpretation(), self, *args, **kwargs) # type: ignore + return apply.__default_rule__(get_interpretation(), self, *(defterm(a) for a in args), **{k: defterm(v) for k, v in kwargs.items()}) # type: ignore def __repr__(self): return f"{self.__class__.__name__}({self.__name__}, {self.__signature__})" diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 81d0be54..edb278cc 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -540,3 +540,21 @@ def test_defstream_1(): # assert isinstance(tm_iter_next, numbers.Number) # TODO # assert issubclass(typeof(tm_iter_next), numbers.Number) assert tm_iter_next.op is next_ + + +def test_defterm_sequence(): + @defop + def my_sequence(xs: tuple[int, ...]) -> tuple[int, ...]: + raise NotImplementedError + + x = defop(int, name="x") + y = defop(int, name="y") + tm = my_sequence((4, 5, x() + 1)) + + assert isinstance(tm, Term) + assert tm.op is my_sequence + assert tm.args == ((4, 5, x() + 1),) + + # Test that the term can be evaluated + with handler({my_sequence: lambda xs: tuple(x * 2 for x in xs), x: lambda: 0}): + assert evaluate(tm) == (8, 10, 2) From 4ecc29e870313a7a8eac2d7101b11083a4965841 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 1 Jul 2025 17:53:24 -0400 Subject: [PATCH 02/76] avoid defterm --- effectful/ops/semantics.py | 26 ++++++++++---------------- effectful/ops/syntax.py | 25 +++++++++++-------------- effectful/ops/types.py | 3 +-- tests/test_ops_syntax.py | 8 ++++++++ 4 files changed, 30 insertions(+), 32 deletions(-) diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index f6ce276c..1ff4db22 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -1,3 +1,4 @@ +import collections.abc import contextlib import functools from collections.abc import Callable @@ -255,23 +256,16 @@ def evaluate(expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]: from effectful.internals.runtime import interpreter return interpreter(intp)(evaluate)(expr) - tm = defterm(expr) - if isinstance(tm, Term): - args = tuple(evaluate(arg) for arg in tm.args) - kwargs = {k: evaluate(v) for k, v in tm.kwargs.items()} - return tm.op(*args, **kwargs) + if isinstance(expr, Term): + args = tuple(evaluate(arg) for arg in expr.args) + kwargs = {k: evaluate(v) for k, v in expr.kwargs.items()} + return expr.op(*args, **kwargs) + elif isinstance(expr, collections.abc.Mapping) and not isinstance(expr, Interpretation): + return type(expr)((evaluate(k), evaluate(v)) for k, v in expr.items()) + elif isinstance(expr, collections.abc.Sequence) and not isinstance(expr, str): + return type(expr)(evaluate(e) for e in expr) else: - return tm - - #if isinstance(expr, Term): - # (args, kwargs) = tree.map_structure( - # functools.partial(evaluate, intp=intp), (expr.args, expr.kwargs) - # ) - # return apply.__default_rule__(intp, expr.op, *args, **kwargs) - #elif tree.is_nested(expr): - # return tree.map_structure(functools.partial(evaluate, intp=intp), expr) - #else: - # return expr + return expr def typeof(term: Expr[T]) -> type[T]: diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 339779f6..62abba07 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1147,7 +1147,7 @@ def __next__(self: collections.abc.Iterator[T]) -> T: class _CollectionTerm(Generic[T], _IterableTerm[T]): @defop def __contains__(self: collections.abc.Collection[T], item: T) -> bool: - if not isinstance(self, Term): + if not isinstance(self, Term) and not isinstance(item, Term): return item in self else: raise NotImplementedError @@ -1160,25 +1160,22 @@ def __len__(self: collections.abc.Collection[T]) -> int: raise NotImplementedError -@defterm.register(collections.abc.Sequence) -def _(value: collections.abc.Sequence[Expr[T]]) -> Expr[collections.abc.Sequence[T]]: - tp = type(value) +@defdata.register(collections.abc.Sequence) +class _SequenceTerm(Generic[T], _CollectionTerm[T], collections.abc.Sequence[T]): @defop - def _reconstructor(*items: Expr[T]) -> tp: # type: ignore - if not any(isinstance(e, Term) for e in items): - return tree.sequence._sequence_like(value, items) + def __getitem__(self: collections.abc.Sequence[T], index: int) -> T: + if not isinstance(self, Term) and not isinstance(index, Term): + return self[index] else: raise NotImplementedError - return _reconstructor(*value) - -@defdata.register(collections.abc.Sequence) -class _SequenceTerm(Generic[T], _CollectionTerm[T], collections.abc.Sequence[T]): +@defdata.register(collections.abc.Mapping) +class _MappingTerm(Generic[S, V], _CollectionTerm[S], collections.abc.Mapping[S, V]): @defop - def __getitem__(self: collections.abc.Sequence[T], index: int) -> T: - if not isinstance(self, Term): - return self[index] + def __getitem__(self: collections.abc.Mapping[S, V], key: S) -> V: + if not isinstance(self, Term) and not isinstance(key, Term): + return self[key] else: raise NotImplementedError diff --git a/effectful/ops/types.py b/effectful/ops/types.py index 7302a9f7..ec313290 100644 --- a/effectful/ops/types.py +++ b/effectful/ops/types.py @@ -84,9 +84,8 @@ def __fvs_rule__( def __call__(self, *args: Q.args, **kwargs: Q.kwargs) -> V: from effectful.internals.runtime import get_interpretation from effectful.ops.semantics import apply - from effectful.ops.syntax import defterm - return apply.__default_rule__(get_interpretation(), self, *(defterm(a) for a in args), **{k: defterm(v) for k, v in kwargs.items()}) # type: ignore + return apply.__default_rule__(get_interpretation(), self, *args, **kwargs) # type: ignore def __repr__(self): return f"{self.__class__.__name__}({self.__name__}, {self.__signature__})" diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index edb278cc..bcf8703d 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -1,3 +1,4 @@ +import collections.abc import functools import inspect from collections.abc import Callable, Iterable, Iterator, Mapping @@ -552,9 +553,16 @@ def my_sequence(xs: tuple[int, ...]) -> tuple[int, ...]: tm = my_sequence((4, 5, x() + 1)) assert isinstance(tm, Term) + assert issubclass(typeof(tm), collections.abc.Sequence) assert tm.op is my_sequence assert tm.args == ((4, 5, x() + 1),) + tm_0 = tm[0] + assert isinstance(tm_0, Term) + assert isinstance(tm_0.op, Operation) + assert tm_0.op is type(tm).__getitem__ + assert tm_0.args == (tm, 0) + # Test that the term can be evaluated with handler({my_sequence: lambda xs: tuple(x * 2 for x in xs), x: lambda: 0}): assert evaluate(tm) == (8, 10, 2) From 76abd887ffaca22f67bf9bd803ace1c3eaa53e3e Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 1 Jul 2025 17:57:30 -0400 Subject: [PATCH 03/76] format --- effectful/ops/semantics.py | 3 +-- tests/test_ops_syntax.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 1ff4db22..03fa2796 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -4,10 +4,9 @@ from collections.abc import Callable from typing import Any, TypeVar -import tree from typing_extensions import ParamSpec -from effectful.ops.syntax import deffn, defop, defterm +from effectful.ops.syntax import deffn, defop from effectful.ops.types import Expr, Interpretation, Operation, Term P = ParamSpec("P") diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index bcf8703d..9dbdf35a 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -553,6 +553,7 @@ def my_sequence(xs: tuple[int, ...]) -> tuple[int, ...]: tm = my_sequence((4, 5, x() + 1)) assert isinstance(tm, Term) + assert isinstance(tm, collections.abc.Sequence) assert issubclass(typeof(tm), collections.abc.Sequence) assert tm.op is my_sequence assert tm.args == ((4, 5, x() + 1),) @@ -566,3 +567,29 @@ def my_sequence(xs: tuple[int, ...]) -> tuple[int, ...]: # Test that the term can be evaluated with handler({my_sequence: lambda xs: tuple(x * 2 for x in xs), x: lambda: 0}): assert evaluate(tm) == (8, 10, 2) + + +def test_defterm_mapping(): + @defop + def my_mapping(k: int, v: int) -> Mapping[int, int]: + raise NotImplementedError + + x = defop(int, name="x") + tm = my_mapping(3, x() + 1) + + assert isinstance(tm, Term) + assert isinstance(tm, collections.abc.Mapping) + assert issubclass(typeof(tm), collections.abc.Mapping) + assert tm.op is my_mapping + assert tm.args == (3, x() + 1) + assert tm.kwargs == {} + + tm_3 = tm[3] + assert isinstance(tm_3, Term) + assert isinstance(tm_3.op, Operation) + assert tm_3.op is type(tm).__getitem__ + assert tm_3.args == (tm, 3) + + # Test that the term can be evaluated + with handler({my_mapping: lambda k, v: {k: v}, x: lambda: 0}): + assert evaluate(tm) == {3: 1} From abd67c674dc6eb74b91c88451d5f04c91353e6c6 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 1 Jul 2025 18:12:37 -0400 Subject: [PATCH 04/76] nit --- tests/test_ops_syntax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 9dbdf35a..9751dcb0 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -571,7 +571,7 @@ def my_sequence(xs: tuple[int, ...]) -> tuple[int, ...]: def test_defterm_mapping(): @defop - def my_mapping(k: int, v: int) -> Mapping[int, int]: + def my_mapping(k: int, v: int) -> dict[int, int]: raise NotImplementedError x = defop(int, name="x") From 0aae217d359e4a4dc4229f28078a8bcee70d39a0 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 2 Jul 2025 18:09:33 -0400 Subject: [PATCH 05/76] stash --- effectful/ops/syntax.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 62abba07..5df14568 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1,3 +1,4 @@ +import abc import collections.abc import dataclasses import functools @@ -545,6 +546,10 @@ def __init__( self._freshening = freshening or [] self.__signature__ = inspect.signature(default) + @property + def __isabstractmethod__(self) -> bool: + return False + def __eq__(self, other): if not isinstance(other, Operation): return NotImplemented @@ -1047,6 +1052,33 @@ def kwargs(self): return self._kwargs +@functools.cache +def termcls(interface: type[abc.ABC]) -> type[Term]: + """ + Create a Term subclass that converts the abstract methods of the given ABC into Operations. + """ + if issubclass(interface, Term) or not interface.__abstractmethods__: + raise TypeError(f"cls must be an abstract non-Term class with abstract methods, got {interface.__name__}") + + if typing.get_origin(interface) not in (None, interface): + return termcls(typing.get_origin(interface)) + + @abc.update_abstractmethods + class _ProxyTerm(_BaseTerm, interface): + for _attr in interface.__abstractmethods__: + if _attr in interface.__dict__: + locals()[_attr] = defop(interface.__dict__[_attr]) + else: + for _basecls in inspect.getmro(interface): + if _attr in _basecls.__dict__ and isinstance(_basecls, abc.ABCMeta) and _basecls.__dict__[_attr].__isabstractmethod__: + locals()[_attr] = termcls(_basecls).__dict__[_attr] + break + del _basecls + del _attr + + return _ProxyTerm + + @defdata.register(collections.abc.Callable) class _CallableTerm(Generic[P, T], _BaseTerm[collections.abc.Callable[P, T]]): def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]: From 24bf5cb634cb533186dc56e843c2e742d9c399e2 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Jul 2025 09:59:15 -0400 Subject: [PATCH 06/76] fix bugs --- effectful/ops/semantics.py | 10 ++++- effectful/ops/syntax.py | 77 ++++++++++++++------------------------ tests/test_ops_syntax.py | 1 - 3 files changed, 36 insertions(+), 52 deletions(-) diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 03fa2796..f747309a 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -251,18 +251,24 @@ def evaluate(expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]: 6 """ + from effectful.internals.runtime import get_interpretation, interpreter + if intp is not None: - from effectful.internals.runtime import interpreter return interpreter(intp)(evaluate)(expr) if isinstance(expr, Term): args = tuple(evaluate(arg) for arg in expr.args) kwargs = {k: evaluate(v) for k, v in expr.kwargs.items()} return expr.op(*args, **kwargs) - elif isinstance(expr, collections.abc.Mapping) and not isinstance(expr, Interpretation): + elif isinstance(expr, collections.abc.Mapping): return type(expr)((evaluate(k), evaluate(v)) for k, v in expr.items()) elif isinstance(expr, collections.abc.Sequence) and not isinstance(expr, str): return type(expr)(evaluate(e) for e in expr) + elif isinstance(expr, collections.abc.Set): + return type(expr)(evaluate(e) for e in expr) + elif isinstance(expr, Operation): + op_intp = get_interpretation().get(expr, expr) + return op_intp if isinstance(op_intp, Operation) else expr else: return expr diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 5df14568..7ff75f9f 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1,4 +1,3 @@ -import abc import collections.abc import dataclasses import functools @@ -9,7 +8,6 @@ from collections.abc import Callable, Iterable, Mapping from typing import Annotated, Concatenate, Generic, TypeVar -import tree from typing_extensions import ParamSpec from effectful.ops.types import Annotation, Expr, Operation, Term @@ -366,16 +364,17 @@ def analyze(self, bound_sig: inspect.BoundArguments) -> frozenset[Operation]: else: param_bound_vars = {param_value} elif param_ordinal: # Only process if there's a Scoped annotation - # We can't use tree.flatten here because we want to be able - # to see dict keys + def extract_operations(obj): if isinstance(obj, Operation): param_bound_vars.add(obj) - elif isinstance(obj, dict): + elif isinstance(obj, collections.abc.Mapping): for k, v in obj.items(): extract_operations(k) extract_operations(v) - elif isinstance(obj, list | set | tuple): + elif isinstance( + obj, collections.abc.Sequence | collections.abc.Set + ): for v in obj: extract_operations(v) @@ -675,6 +674,7 @@ def func(*args, **kwargs): @defop.register(type) +@defop.register(types.GenericAlias) def _(t: type[T], *, name: str | None = None) -> Operation[[], T]: def func() -> t: # type: ignore raise NotImplementedError @@ -982,9 +982,6 @@ def _(op, *args, **kwargs): *{k: (v, kwarg_ctxs[k]) for k, v in kwargs.items()}.items(), ): if c: - v = tree.map_structure( - lambda a: renaming.get(a, a) if isinstance(a, Operation) else a, v - ) res = evaluate( v, intp={ @@ -1052,33 +1049,6 @@ def kwargs(self): return self._kwargs -@functools.cache -def termcls(interface: type[abc.ABC]) -> type[Term]: - """ - Create a Term subclass that converts the abstract methods of the given ABC into Operations. - """ - if issubclass(interface, Term) or not interface.__abstractmethods__: - raise TypeError(f"cls must be an abstract non-Term class with abstract methods, got {interface.__name__}") - - if typing.get_origin(interface) not in (None, interface): - return termcls(typing.get_origin(interface)) - - @abc.update_abstractmethods - class _ProxyTerm(_BaseTerm, interface): - for _attr in interface.__abstractmethods__: - if _attr in interface.__dict__: - locals()[_attr] = defop(interface.__dict__[_attr]) - else: - for _basecls in inspect.getmro(interface): - if _attr in _basecls.__dict__ and isinstance(_basecls, abc.ABCMeta) and _basecls.__dict__[_attr].__isabstractmethod__: - locals()[_attr] = termcls(_basecls).__dict__[_attr] - break - del _basecls - del _attr - - return _ProxyTerm - - @defdata.register(collections.abc.Callable) class _CallableTerm(Generic[P, T], _BaseTerm[collections.abc.Callable[P, T]]): def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]: @@ -1152,7 +1122,9 @@ def defstream( @defdata.register(collections.abc.Iterable) -class _IterableTerm(Generic[T], _BaseTerm[collections.abc.Iterable[T]], collections.abc.Iterable[T]): +class _IterableTerm( + Generic[T], _BaseTerm[collections.abc.Iterable[T]], collections.abc.Iterable[T] +): @defop def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: if not isinstance(self, Term): @@ -1224,21 +1196,28 @@ def syntactic_eq(x: Expr[T], other: Expr[T]) -> bool: if isinstance(x, Term) and isinstance(other, Term): op, args, kwargs = x.op, x.args, x.kwargs op2, args2, kwargs2 = other.op, other.args, other.kwargs - try: - tree.assert_same_structure( - (op, args, kwargs), (op2, args2, kwargs2), check_types=True - ) - except (TypeError, ValueError): - return False - return all( - tree.flatten( - tree.map_structure( - syntactic_eq, (op, args, kwargs), (op2, args2, kwargs2) - ) - ) + return ( + op == op2 + and len(args) == len(args2) + and set(kwargs) == set(kwargs2) + and all(syntactic_eq(a, b) for a, b in zip(args, args2)) + and all(syntactic_eq(kwargs[k], kwargs2[k]) for k in kwargs) ) elif isinstance(x, Term) or isinstance(other, Term): return False + elif isinstance(x, collections.abc.Mapping) and isinstance( + other, collections.abc.Mapping + ): + return all( + k in x and k in other and syntactic_eq(x[k], other[k]) + for k in set(x) | set(other) + ) + elif isinstance(x, collections.abc.Sequence) and isinstance( + other, collections.abc.Sequence + ): + return len(x) == len(other) and all( + syntactic_eq(a, b) for a, b in zip(x, other) + ) else: return x == other diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 9751dcb0..ad09efa0 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -549,7 +549,6 @@ def my_sequence(xs: tuple[int, ...]) -> tuple[int, ...]: raise NotImplementedError x = defop(int, name="x") - y = defop(int, name="y") tm = my_sequence((4, 5, x() + 1)) assert isinstance(tm, Term) From 284ac3e9d7741e4235aac21f1f8bd6909e7bd0bf Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Jul 2025 10:40:42 -0400 Subject: [PATCH 07/76] nit --- effectful/ops/syntax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 7ff75f9f..f6727772 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -694,7 +694,7 @@ def func() -> t: # type: ignore def _(t: Callable[P, T], *, name: str | None = None) -> Operation[P, T]: @functools.wraps(t) def func(*args, **kwargs): - if not any(isinstance(defterm(a), Term) for a in (*args, *kwargs.values())): + if not any(isinstance(a, Term) for a in (*args, *kwargs.values())): return t(*args, **kwargs) else: raise NotImplementedError From 49f678f96692149d57e3a15da809a82676317af7 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Jul 2025 15:51:40 -0400 Subject: [PATCH 08/76] first commit --- effectful/internals/unification.py | 154 ++++++++++++++++++++++++++++ tests/test_internals_unification.py | 44 ++++++++ 2 files changed, 198 insertions(+) create mode 100644 effectful/internals/unification.py create mode 100644 tests/test_internals_unification.py diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py new file mode 100644 index 00000000..fca04a13 --- /dev/null +++ b/effectful/internals/unification.py @@ -0,0 +1,154 @@ +import collections.abc +import inspect +import types +import typing + + +def infer_return_type(bound_sig: inspect.BoundArguments) -> type: + """ + Infer the return type of a function based on its signature and argument types. + """ + bound_sig.apply_defaults() + sig: inspect.Signature = bound_sig.signature + + # validate that the function has a signature with well-formed type annotations + if sig.return_annotation is inspect.Signature.empty: + raise TypeError("Function must have a return type annotation") + + if any(p.annotation is inspect.Signature.empty for p in sig.parameters.values()): + raise TypeError("All parameters must have type annotations") + + result_fvs: set[typing.TypeVar] = freetypevars(sig.return_annotation) + pattern_fvs: set[typing.TypeVar] = set.union( + *(freetypevars(p.annotation) for p in sig.parameters.values()), + ) + concrete_fvs: set[typing.TypeVar] = set.union( + *(freetypevars(arg) for arg in bound_sig.arguments.values()), + ) + if (result_fvs | pattern_fvs) & concrete_fvs: + raise TypeError("Cannot unify free type variables in pattern and concrete types") + if not result_fvs <= pattern_fvs: + raise TypeError("unbound type variables in return type") + + # Check for variadic parameters and collections - not implemented yet + for name, param in sig.parameters.items(): + if param.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}: + raise NotImplementedError(f"Parameter '{name}' cannot be variadic") + + if isinstance(bound_sig.arguments[name], collections.abc.Collection): + raise NotImplementedError(f"Parameter '{name}' cannot be a collection type") + + if freetypevars(bound_sig.arguments[name]): + raise NotImplementedError(f"Parameter '{name}' cannot have free type variables") + + # Build substitution map + subs = {} + for name in sig.parameters: + subs = unify(sig.parameters[name].annotation, bound_sig.arguments[name], subs) + + # Apply substitutions to return type + result_type: type = substitute(sig.return_annotation, subs) + if freetypevars(result_type): + raise TypeError("Return type cannot have free type variables after substitution") + return result_type + + +def unify( + pattern: type | types.GenericAlias | collections.abc.Mapping | collections.abc.Sequence, + concrete: type | types.GenericAlias | collections.abc.Mapping | collections.abc.Sequence, + subs: collections.abc.Mapping[typing.TypeVar, type], +) -> collections.abc.Mapping[typing.TypeVar, type]: + """ + Unify a pattern type with a concrete type, returning a substitution map. + Raises TypeError if unification is not possible. + """ + if isinstance(pattern, typing.TypeVar): + if pattern in subs and subs[pattern] != concrete: + raise TypeError(f"Cannot unify {pattern} with {concrete} (already unified with {subs[pattern]})") + return {**subs, pattern: concrete} + elif typing.get_args(pattern) and typing.get_args(concrete): + if typing.get_origin(pattern) != typing.get_origin(concrete): + raise TypeError(f"Cannot unify {pattern} with {concrete}") + return unify(typing.get_args(pattern), typing.get_args(concrete), subs) + elif isinstance(pattern, collections.abc.Mapping) and isinstance(concrete, collections.abc.Mapping): + if pattern.keys() != concrete.keys(): + raise TypeError(f"Cannot unify {pattern} with {concrete}") + for key in pattern: + subs = unify(pattern[key], concrete[key], subs) + return subs + elif isinstance(pattern, collections.abc.Sequence) and isinstance(concrete, collections.abc.Sequence): + if len(pattern) != len(concrete): + raise TypeError(f"Cannot unify {pattern} with {concrete}") + for p_item, c_item in zip(pattern, concrete): + subs = unify(p_item, c_item, subs) + return subs + else: + if pattern != concrete: + raise TypeError(f"Cannot unify {pattern} with {concrete}") + return subs + + +def freetypevars(typ: type | types.GenericAlias) -> set[typing.TypeVar]: + """ + Return a set of free type variables in the given type. + """ + if isinstance(typ, typing.TypeVar): + return {typ} + elif typing.get_args(typ): + return set.union(*(freetypevars(arg) for arg in typing.get_args(typ))) + else: + return set() + + +def substitute(typ: type | types.GenericAlias, subs: collections.abc.Mapping[typing.TypeVar, type]) -> type: + """ + Substitute type variables in a type expression with concrete types. + + This function recursively traverses a type expression and replaces any TypeVar + instances found with their corresponding concrete types from the substitution + mapping. If a TypeVar is not present in the substitution mapping, it remains + unchanged. The function handles nested generic types by recursively substituting + in their type arguments. + + Args: + typ: The type expression to perform substitution on. Can be a plain type, + a TypeVar, or a generic type alias (e.g., List[T], Dict[K, V]). + subs: A mapping from TypeVar instances to concrete types that should + replace them. + + Returns: + A new type expression with all mapped TypeVars replaced by their + corresponding concrete types. + + Examples: + >>> T = typing.TypeVar('T') + >>> K = typing.TypeVar('K') + >>> V = typing.TypeVar('V') + + >>> # Simple TypeVar substitution + >>> substitute(T, {T: int}) + + + >>> # Generic type substitution + >>> substitute(typing.List[T], {T: str}) + typing.List[str] + + >>> # Nested generic substitution + >>> substitute(typing.Dict[K, typing.List[V]], {K: str, V: int}) + typing.Dict[str, typing.List[int]] + + >>> # TypeVar not in mapping remains unchanged + >>> substitute(T, {K: int}) + ~T + + >>> # Non-generic types pass through unchanged + >>> substitute(int, {T: str}) + + """ + if isinstance(typ, typing.TypeVar): + return subs.get(typ, typ) + elif typing.get_args(typ) and typing.get_origin(typ) is not None: + new_args = tuple(substitute(arg, subs) for arg in typing.get_args(typ)) + return typing.get_origin(typ)[new_args] + else: + return typ diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py new file mode 100644 index 00000000..088c8086 --- /dev/null +++ b/tests/test_internals_unification.py @@ -0,0 +1,44 @@ +import typing + +import pytest + +from effectful.internals.unification import infer_return_type, unify, substitute, freetypevars + + +@pytest.mark.parametrize( +"typ,fvs", [ + (typing.List[typing.TypeVar("T")], {typing.TypeVar("T")}), + (typing.Dict[str, typing.TypeVar("T")], {typing.TypeVar("T")}), + (int, set()), + (typing.List[int], set()), + (typing.Dict[str, int], set()), +]) +def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): + assert freetypevars(typ) == fvs + + +@pytest.mark.parametrize( + "typ,subs,expected", [ + (typing.List[typing.TypeVar("T")], {typing.TypeVar("T"): int}, typing.List[int]), + (typing.Dict[str, typing.TypeVar("T")], {typing.TypeVar("T"): int}, typing.Dict[str, int]), + (int, {}, int), + (typing.List[int], {}, typing.List[int]), + (typing.Dict[str, int], {}, typing.Dict[str, int]), + ] +) +def test_substitute(typ: type, subs: typing.Mapping[typing.TypeVar, type], expected: type): + assert substitute(typ, subs) == expected + + +@pytest.mark.parametrize( + "pattern,concrete,subs,expected", [ + (typing.TypeVar("T"), int, {}, {typing.TypeVar("T"): int}), + (typing.List[typing.TypeVar("T")], typing.List[int], {typing.TypeVar("T"): int}), + ] +) +def test_unify(pattern: type, concrete: type, subs: typing.Mapping[typing.TypeVar, type], expected: typing.Mapping[typing.TypeVar, type]): + assert unify(pattern, concrete, subs) == expected + + +def test_infer_return_type(): + pass # TODO fill this in \ No newline at end of file From d99753b279c4d32dd6972a4b550b0bd504dcf7e5 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Jul 2025 15:55:50 -0400 Subject: [PATCH 09/76] format --- effectful/internals/unification.py | 77 +++++++++++++++++++---------- tests/test_internals_unification.py | 55 +++++++++++++-------- 2 files changed, 86 insertions(+), 46 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index fca04a13..c1b799bc 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -4,7 +4,7 @@ import typing -def infer_return_type(bound_sig: inspect.BoundArguments) -> type: +def infer_return_type(bound_sig: inspect.BoundArguments) -> type | types.GenericAlias: """ Infer the return type of a function based on its signature and argument types. """ @@ -14,7 +14,7 @@ def infer_return_type(bound_sig: inspect.BoundArguments) -> type: # validate that the function has a signature with well-formed type annotations if sig.return_annotation is inspect.Signature.empty: raise TypeError("Function must have a return type annotation") - + if any(p.annotation is inspect.Signature.empty for p in sig.parameters.values()): raise TypeError("All parameters must have type annotations") @@ -26,36 +26,51 @@ def infer_return_type(bound_sig: inspect.BoundArguments) -> type: *(freetypevars(arg) for arg in bound_sig.arguments.values()), ) if (result_fvs | pattern_fvs) & concrete_fvs: - raise TypeError("Cannot unify free type variables in pattern and concrete types") + raise TypeError( + "Cannot unify free type variables in pattern and concrete types" + ) if not result_fvs <= pattern_fvs: raise TypeError("unbound type variables in return type") # Check for variadic parameters and collections - not implemented yet for name, param in sig.parameters.items(): - if param.kind in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}: + if param.kind in { + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + }: raise NotImplementedError(f"Parameter '{name}' cannot be variadic") if isinstance(bound_sig.arguments[name], collections.abc.Collection): raise NotImplementedError(f"Parameter '{name}' cannot be a collection type") if freetypevars(bound_sig.arguments[name]): - raise NotImplementedError(f"Parameter '{name}' cannot have free type variables") + raise NotImplementedError( + f"Parameter '{name}' cannot have free type variables" + ) # Build substitution map - subs = {} + subs: collections.abc.Mapping[typing.TypeVar, type] = {} for name in sig.parameters: subs = unify(sig.parameters[name].annotation, bound_sig.arguments[name], subs) - + # Apply substitutions to return type - result_type: type = substitute(sig.return_annotation, subs) + result_type = substitute(sig.return_annotation, subs) if freetypevars(result_type): - raise TypeError("Return type cannot have free type variables after substitution") + raise TypeError( + "Return type cannot have free type variables after substitution" + ) return result_type def unify( - pattern: type | types.GenericAlias | collections.abc.Mapping | collections.abc.Sequence, - concrete: type | types.GenericAlias | collections.abc.Mapping | collections.abc.Sequence, + pattern: type + | types.GenericAlias + | collections.abc.Mapping + | collections.abc.Sequence, + concrete: type + | types.GenericAlias + | collections.abc.Mapping + | collections.abc.Sequence, subs: collections.abc.Mapping[typing.TypeVar, type], ) -> collections.abc.Mapping[typing.TypeVar, type]: """ @@ -64,19 +79,25 @@ def unify( """ if isinstance(pattern, typing.TypeVar): if pattern in subs and subs[pattern] != concrete: - raise TypeError(f"Cannot unify {pattern} with {concrete} (already unified with {subs[pattern]})") + raise TypeError( + f"Cannot unify {pattern} with {concrete} (already unified with {subs[pattern]})" + ) return {**subs, pattern: concrete} elif typing.get_args(pattern) and typing.get_args(concrete): if typing.get_origin(pattern) != typing.get_origin(concrete): raise TypeError(f"Cannot unify {pattern} with {concrete}") return unify(typing.get_args(pattern), typing.get_args(concrete), subs) - elif isinstance(pattern, collections.abc.Mapping) and isinstance(concrete, collections.abc.Mapping): + elif isinstance(pattern, collections.abc.Mapping) and isinstance( + concrete, collections.abc.Mapping + ): if pattern.keys() != concrete.keys(): raise TypeError(f"Cannot unify {pattern} with {concrete}") for key in pattern: subs = unify(pattern[key], concrete[key], subs) return subs - elif isinstance(pattern, collections.abc.Sequence) and isinstance(concrete, collections.abc.Sequence): + elif isinstance(pattern, collections.abc.Sequence) and isinstance( + concrete, collections.abc.Sequence + ): if len(pattern) != len(concrete): raise TypeError(f"Cannot unify {pattern} with {concrete}") for p_item, c_item in zip(pattern, concrete): @@ -100,55 +121,59 @@ def freetypevars(typ: type | types.GenericAlias) -> set[typing.TypeVar]: return set() -def substitute(typ: type | types.GenericAlias, subs: collections.abc.Mapping[typing.TypeVar, type]) -> type: +def substitute( + typ: type | types.GenericAlias, subs: collections.abc.Mapping[typing.TypeVar, type] +) -> type | types.GenericAlias: """ Substitute type variables in a type expression with concrete types. - + This function recursively traverses a type expression and replaces any TypeVar instances found with their corresponding concrete types from the substitution mapping. If a TypeVar is not present in the substitution mapping, it remains unchanged. The function handles nested generic types by recursively substituting in their type arguments. - + Args: typ: The type expression to perform substitution on. Can be a plain type, a TypeVar, or a generic type alias (e.g., List[T], Dict[K, V]). subs: A mapping from TypeVar instances to concrete types that should replace them. - + Returns: A new type expression with all mapped TypeVars replaced by their corresponding concrete types. - + Examples: >>> T = typing.TypeVar('T') >>> K = typing.TypeVar('K') >>> V = typing.TypeVar('V') - + >>> # Simple TypeVar substitution >>> substitute(T, {T: int}) - + >>> # Generic type substitution >>> substitute(typing.List[T], {T: str}) typing.List[str] - + >>> # Nested generic substitution >>> substitute(typing.Dict[K, typing.List[V]], {K: str, V: int}) typing.Dict[str, typing.List[int]] - + >>> # TypeVar not in mapping remains unchanged >>> substitute(T, {K: int}) ~T - + >>> # Non-generic types pass through unchanged >>> substitute(int, {T: str}) """ if isinstance(typ, typing.TypeVar): return subs.get(typ, typ) - elif typing.get_args(typ) and typing.get_origin(typ) is not None: + elif typing.get_args(typ): + origin = typing.get_origin(typ) + assert origin is not None, "Type must have an origin" new_args = tuple(substitute(arg, subs) for arg in typing.get_args(typ)) - return typing.get_origin(typ)[new_args] + return origin[new_args] else: return typ diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 088c8086..637263d9 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -2,43 +2,58 @@ import pytest -from effectful.internals.unification import infer_return_type, unify, substitute, freetypevars +from effectful.internals.unification import ( + freetypevars, + substitute, + unify, +) @pytest.mark.parametrize( -"typ,fvs", [ - (typing.List[typing.TypeVar("T")], {typing.TypeVar("T")}), - (typing.Dict[str, typing.TypeVar("T")], {typing.TypeVar("T")}), - (int, set()), - (typing.List[int], set()), - (typing.Dict[str, int], set()), -]) + "typ,fvs", + [ + (list[typing.TypeVar("T")], {typing.TypeVar("T")}), + (dict[str, typing.TypeVar("T")], {typing.TypeVar("T")}), + (int, set()), + (list[int], set()), + (dict[str, int], set()), + ], +) def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): assert freetypevars(typ) == fvs @pytest.mark.parametrize( - "typ,subs,expected", [ - (typing.List[typing.TypeVar("T")], {typing.TypeVar("T"): int}, typing.List[int]), - (typing.Dict[str, typing.TypeVar("T")], {typing.TypeVar("T"): int}, typing.Dict[str, int]), + "typ,subs,expected", + [ + (list[typing.TypeVar("T")], {typing.TypeVar("T"): int}, list[int]), + (dict[str, typing.TypeVar("T")], {typing.TypeVar("T"): int}, dict[str, int]), (int, {}, int), - (typing.List[int], {}, typing.List[int]), - (typing.Dict[str, int], {}, typing.Dict[str, int]), - ] + (list[int], {}, list[int]), + (dict[str, int], {}, dict[str, int]), + ], ) -def test_substitute(typ: type, subs: typing.Mapping[typing.TypeVar, type], expected: type): +def test_substitute( + typ: type, subs: typing.Mapping[typing.TypeVar, type], expected: type +): assert substitute(typ, subs) == expected @pytest.mark.parametrize( - "pattern,concrete,subs,expected", [ + "pattern,concrete,subs,expected", + [ (typing.TypeVar("T"), int, {}, {typing.TypeVar("T"): int}), - (typing.List[typing.TypeVar("T")], typing.List[int], {typing.TypeVar("T"): int}), - ] + (list[typing.TypeVar("T")], list[int], {typing.TypeVar("T"): int}), + ], ) -def test_unify(pattern: type, concrete: type, subs: typing.Mapping[typing.TypeVar, type], expected: typing.Mapping[typing.TypeVar, type]): +def test_unify( + pattern: type, + concrete: type, + subs: typing.Mapping[typing.TypeVar, type], + expected: typing.Mapping[typing.TypeVar, type], +): assert unify(pattern, concrete, subs) == expected def test_infer_return_type(): - pass # TODO fill this in \ No newline at end of file + pass # TODO fill this in From b4417f3c1ca2635342fe3c0a2753ae3c607ac082 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Jul 2025 15:59:17 -0400 Subject: [PATCH 10/76] test passes --- tests/test_internals_unification.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 637263d9..e6313aae 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -8,12 +8,14 @@ unify, ) +T = typing.TypeVar("T") + @pytest.mark.parametrize( "typ,fvs", [ - (list[typing.TypeVar("T")], {typing.TypeVar("T")}), - (dict[str, typing.TypeVar("T")], {typing.TypeVar("T")}), + (list[T], {T}), + (dict[str, T], {T}), (int, set()), (list[int], set()), (dict[str, int], set()), @@ -26,8 +28,8 @@ def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): @pytest.mark.parametrize( "typ,subs,expected", [ - (list[typing.TypeVar("T")], {typing.TypeVar("T"): int}, list[int]), - (dict[str, typing.TypeVar("T")], {typing.TypeVar("T"): int}, dict[str, int]), + (list[T], {T: int}, list[int]), + (dict[str, T], {T: int}, dict[str, int]), (int, {}, int), (list[int], {}, list[int]), (dict[str, int], {}, dict[str, int]), @@ -40,19 +42,18 @@ def test_substitute( @pytest.mark.parametrize( - "pattern,concrete,subs,expected", + "pattern,concrete,expected_subs", [ - (typing.TypeVar("T"), int, {}, {typing.TypeVar("T"): int}), - (list[typing.TypeVar("T")], list[int], {typing.TypeVar("T"): int}), + (T, int, {T: int}), + (list[T], list[int], {T: int}), ], ) def test_unify( pattern: type, concrete: type, - subs: typing.Mapping[typing.TypeVar, type], - expected: typing.Mapping[typing.TypeVar, type], + expected_subs: typing.Mapping[typing.TypeVar, type], ): - assert unify(pattern, concrete, subs) == expected + assert unify(pattern, concrete, {}) == expected_subs def test_infer_return_type(): From 02eb5ecb313f8b5e498920ad66f6955176df73cb Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Jul 2025 16:27:57 -0400 Subject: [PATCH 11/76] types --- effectful/internals/unification.py | 73 +++++++++++++++++++++---- tests/test_internals_unification.py | 84 +++++++++++++++++++++++++++-- 2 files changed, 142 insertions(+), 15 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index c1b799bc..1c660aa9 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -4,7 +4,7 @@ import typing -def infer_return_type(bound_sig: inspect.BoundArguments) -> type | types.GenericAlias: +def infer_return_type(bound_sig: inspect.BoundArguments) -> type | types.GenericAlias | types.UnionType: """ Infer the return type of a function based on its signature and argument types. """ @@ -64,10 +64,14 @@ def infer_return_type(bound_sig: inspect.BoundArguments) -> type | types.Generic def unify( pattern: type + | typing.TypeVar | types.GenericAlias + | types.UnionType | collections.abc.Mapping | collections.abc.Sequence, concrete: type + | typing.TypeVar + | types.UnionType | types.GenericAlias | collections.abc.Mapping | collections.abc.Sequence, @@ -82,7 +86,7 @@ def unify( raise TypeError( f"Cannot unify {pattern} with {concrete} (already unified with {subs[pattern]})" ) - return {**subs, pattern: concrete} + return {**subs, **{pattern: concrete}} elif typing.get_args(pattern) and typing.get_args(concrete): if typing.get_origin(pattern) != typing.get_origin(concrete): raise TypeError(f"Cannot unify {pattern} with {concrete}") @@ -109,9 +113,55 @@ def unify( return subs -def freetypevars(typ: type | types.GenericAlias) -> set[typing.TypeVar]: +def freetypevars(typ: type | typing.TypeVar | types.GenericAlias | types.UnionType) -> set[typing.TypeVar]: """ - Return a set of free type variables in the given type. + Return a set of free type variables in the given type expression. + + This function recursively traverses a type expression to find all TypeVar + instances that appear within it. It handles both simple types and generic + type aliases with nested type arguments. TypeVars are considered "free" + when they are not bound to a specific concrete type. + + Args: + typ: The type expression to analyze. Can be a plain type (e.g., int), + a TypeVar, or a generic type alias (e.g., List[T], Dict[K, V]). + + Returns: + A set containing all TypeVar instances found in the type expression. + Returns an empty set if no TypeVars are present. + + Examples: + >>> T = typing.TypeVar('T') + >>> K = typing.TypeVar('K') + >>> V = typing.TypeVar('V') + + >>> # TypeVar returns itself + >>> freetypevars(T) + {~T} + + >>> # Generic type with one TypeVar + >>> freetypevars(list[T]) + {~T} + + >>> # Generic type with multiple TypeVars + >>> sorted(freetypevars(dict[K, V]), key=lambda x: x.__name__) + [~K, ~V] + + >>> # Nested generic types + >>> sorted(freetypevars(list[dict[K, V]]), key=lambda x: x.__name__) + [~K, ~V] + + >>> # Concrete types have no free TypeVars + >>> freetypevars(int) + set() + + >>> # Generic types with concrete arguments have no free TypeVars + >>> freetypevars(list[int]) + set() + + >>> # Mixed concrete and TypeVar arguments + >>> freetypevars(dict[str, T]) + {~T} """ if isinstance(typ, typing.TypeVar): return {typ} @@ -122,8 +172,8 @@ def freetypevars(typ: type | types.GenericAlias) -> set[typing.TypeVar]: def substitute( - typ: type | types.GenericAlias, subs: collections.abc.Mapping[typing.TypeVar, type] -) -> type | types.GenericAlias: + typ: type | types.GenericAlias | types.UnionType, subs: collections.abc.Mapping[typing.TypeVar, type] +) -> type | types.GenericAlias | types.UnionType: """ Substitute type variables in a type expression with concrete types. @@ -153,12 +203,12 @@ def substitute( >>> # Generic type substitution - >>> substitute(typing.List[T], {T: str}) - typing.List[str] + >>> substitute(list[T], {T: str}) + list[str] >>> # Nested generic substitution - >>> substitute(typing.Dict[K, typing.List[V]], {K: str, V: int}) - typing.Dict[str, typing.List[int]] + >>> substitute(dict[K, list[V]], {K: str, V: int}) + dict[str, list[int]] >>> # TypeVar not in mapping remains unchanged >>> substitute(T, {K: int}) @@ -174,6 +224,9 @@ def substitute( origin = typing.get_origin(typ) assert origin is not None, "Type must have an origin" new_args = tuple(substitute(arg, subs) for arg in typing.get_args(typ)) + # Handle Union types specially + if origin is types.UnionType: + return typing.Union[new_args] return origin[new_args] else: return typ diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index e6313aae..d26b5c27 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -9,17 +9,48 @@ ) T = typing.TypeVar("T") +K = typing.TypeVar("K") +V = typing.TypeVar("V") +U = typing.TypeVar("U") @pytest.mark.parametrize( "typ,fvs", [ - (list[T], {T}), - (dict[str, T], {T}), + # Basic cases + (T, {T}), (int, set()), + (str, set()), + # Single TypeVar in generic + (list[T], {T}), + (set[T], {T}), + (tuple[T], {T}), + # Multiple TypeVars + (dict[K, V], {K, V}), + (tuple[K, V], {K, V}), + (dict[T, T], {T}), # Same TypeVar used twice + # Nested generics with TypeVars + (list[dict[K, V]], {K, V}), + (dict[K, list[V]], {K, V}), + (list[tuple[T, U]], {T, U}), + (tuple[list[T], dict[K, V]], {T, K, V}), + # Concrete types in generics (list[int], set()), (dict[str, int], set()), + (tuple[int, str, float], set()), + # Mixed concrete and TypeVars + (dict[str, T], {T}), + (dict[K, int], {K}), + (tuple[T, int, V], {T, V}), + (list[tuple[int, T]], {T}), + # Deeply nested + (list[dict[K, list[tuple[V, T]]]], {K, V, T}), + (dict[tuple[K, V], list[dict[U, T]]], {K, V, U, T}), + # Union types (if supported) + (list[T] | dict[K, V], {T, K, V}), + (T | int, {T}), ], + ids=str ) def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): assert freetypevars(typ) == fvs @@ -28,12 +59,54 @@ def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): @pytest.mark.parametrize( "typ,subs,expected", [ + # Basic substitution + (T, {T: int}, int), + (T, {T: str}, str), + (T, {T: list[int]}, list[int]), + # TypeVar not in mapping + (T, {K: int}, T), + (T, {}, T), + # Non-TypeVar types + (int, {T: str}, int), + (str, {}, str), + (list[int], {T: str}, list[int]), + # Single TypeVar in generic (list[T], {T: int}, list[int]), + (set[T], {T: str}, set[str]), + (tuple[T], {T: float}, tuple[float]), + # Multiple TypeVars + (dict[K, V], {K: str, V: int}, dict[str, int]), + (tuple[K, V], {K: int, V: str}, tuple[int, str]), + (dict[K, V], {K: str}, dict[str, V]), # Partial substitution + # Same TypeVar used multiple times + (dict[T, T], {T: int}, dict[int, int]), + (tuple[T, T, T], {T: str}, tuple[str, str, str]), + # Nested generics + (list[dict[K, V]], {K: str, V: int}, list[dict[str, int]]), + (dict[K, list[V]], {K: int, V: str}, dict[int, list[str]]), + (list[tuple[T, U]], {T: int, U: str}, list[tuple[int, str]]), + # Mixed concrete and TypeVars (dict[str, T], {T: int}, dict[str, int]), - (int, {}, int), - (list[int], {}, list[int]), - (dict[str, int], {}, dict[str, int]), + (tuple[int, T, str], {T: float}, tuple[int, float, str]), + (list[tuple[int, T]], {T: str}, list[tuple[int, str]]), + # Deeply nested substitution + (list[dict[K, list[V]]], {K: str, V: int}, list[dict[str, list[int]]]), + (dict[tuple[K, V], list[T]], {K: int, V: str, T: float}, dict[tuple[int, str], list[float]]), + # Substituting with generic types + (T, {T: list[int]}, list[int]), + (list[T], {T: dict[str, int]}, list[dict[str, int]]), + (dict[K, V], {K: list[int], V: dict[str, float]}, dict[list[int], dict[str, float]]), + # Empty substitution + (list[T], {}, list[T]), + (dict[K, V], {}, dict[K, V]), + # Union types (if supported) + (T | int, {T: str}, str | int), + (list[T] | dict[K, V], {T: int, K: str, V: float}, list[int] | dict[str, float]), + # Irrelevant substitutions (TypeVars not in type) + (list[T], {K: int, V: str}, list[T]), + (int, {T: str, K: int}, int), ], + ids=str ) def test_substitute( typ: type, subs: typing.Mapping[typing.TypeVar, type], expected: type @@ -47,6 +120,7 @@ def test_substitute( (T, int, {T: int}), (list[T], list[int], {T: int}), ], + ids=str ) def test_unify( pattern: type, From 8a5152d9dcd7b1c3fd04c453862726f6bc969e59 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Jul 2025 16:41:11 -0400 Subject: [PATCH 12/76] add docstring with doctests for unify --- effectful/internals/unification.py | 172 +++++++++++++++++++++------- tests/test_internals_unification.py | 85 ++++++++++++-- 2 files changed, 207 insertions(+), 50 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 1c660aa9..27abffd1 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -4,7 +4,9 @@ import typing -def infer_return_type(bound_sig: inspect.BoundArguments) -> type | types.GenericAlias | types.UnionType: +def infer_return_type( + bound_sig: inspect.BoundArguments, +) -> type | types.GenericAlias | types.UnionType: """ Infer the return type of a function based on its signature and argument types. """ @@ -63,13 +65,13 @@ def infer_return_type(bound_sig: inspect.BoundArguments) -> type | types.Generic def unify( - pattern: type + typ: type | typing.TypeVar | types.GenericAlias | types.UnionType | collections.abc.Mapping | collections.abc.Sequence, - concrete: type + subtyp: type | typing.TypeVar | types.UnionType | types.GenericAlias @@ -79,100 +81,180 @@ def unify( ) -> collections.abc.Mapping[typing.TypeVar, type]: """ Unify a pattern type with a concrete type, returning a substitution map. - Raises TypeError if unification is not possible. + + This function attempts to find a substitution of type variables that makes + the pattern type (typ) equal to the concrete type (subtyp). It updates + and returns the substitution mapping, or raises TypeError if unification + is not possible. + + The function handles: + - TypeVar unification (binding type variables to concrete types) + - Generic type unification (matching origins and recursively unifying args) + - Structural unification of sequences and mappings + - Exact type matching for non-generic types + + Args: + typ: The pattern type that may contain TypeVars to be unified + subtyp: The concrete type to unify with the pattern + subs: Existing substitution mappings to be extended (not modified) + + Returns: + A new substitution mapping that includes all previous substitutions + plus any new TypeVar bindings discovered during unification. + + Raises: + TypeError: If unification is not possible (incompatible types or + conflicting TypeVar bindings) + + Examples: + >>> import typing + >>> T = typing.TypeVar('T') + >>> K = typing.TypeVar('K') + >>> V = typing.TypeVar('V') + + >>> # Simple TypeVar unification + >>> unify(T, int, {}) + {~T: } + + >>> # Generic type unification + >>> unify(list[T], list[int], {}) + {~T: } + + >>> # Multiple TypeVars + >>> unify(dict[K, V], dict[str, int], {}) + {~K: , ~V: } + + >>> # With existing substitutions + >>> unify(V, bool, {T: int}) + {~T: , ~V: } + + >>> # Nested generic unification + >>> unify(list[dict[K, V]], list[dict[str, int]], {}) + {~K: , ~V: } + + >>> # Exact type matching + >>> unify(int, int, {}) + {} + + >>> # Failed unification - incompatible types + >>> unify(list[T], dict[str, int], {}) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + TypeError: Cannot unify list[~T] with dict[str, int] + + >>> # Failed unification - conflicting TypeVar binding + >>> unify(T, str, {T: int}) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + TypeError: Cannot unify ~T with (already unified with ) + + >>> # Callable type unification + >>> unify(typing.Callable[[T], V], typing.Callable[[int], str], {}) + {~T: , ~V: } + + >>> # Sequence unification (tuples as sequences) + >>> unify((T, V), (int, str), {}) + {~T: , ~V: } """ - if isinstance(pattern, typing.TypeVar): - if pattern in subs and subs[pattern] != concrete: + if isinstance(typ, typing.TypeVar): + if typ in subs and subs[typ] != subtyp: raise TypeError( - f"Cannot unify {pattern} with {concrete} (already unified with {subs[pattern]})" + f"Cannot unify {typ} with {subtyp} (already unified with {subs[typ]})" ) - return {**subs, **{pattern: concrete}} - elif typing.get_args(pattern) and typing.get_args(concrete): - if typing.get_origin(pattern) != typing.get_origin(concrete): - raise TypeError(f"Cannot unify {pattern} with {concrete}") - return unify(typing.get_args(pattern), typing.get_args(concrete), subs) - elif isinstance(pattern, collections.abc.Mapping) and isinstance( - concrete, collections.abc.Mapping + return {**subs, **{typ: subtyp}} + elif typing.get_args(typ) and typing.get_args(subtyp): + if typing.get_origin(typ) != typing.get_origin(subtyp): + raise TypeError(f"Cannot unify {typ} with {subtyp}") + return unify(typing.get_args(typ), typing.get_args(subtyp), subs) + elif isinstance(typ, collections.abc.Mapping) and isinstance( + subtyp, collections.abc.Mapping ): - if pattern.keys() != concrete.keys(): - raise TypeError(f"Cannot unify {pattern} with {concrete}") - for key in pattern: - subs = unify(pattern[key], concrete[key], subs) + if typ.keys() != subtyp.keys(): + raise TypeError(f"Cannot unify {typ} with {subtyp}") + for key in typ: + subs = unify(typ[key], subtyp[key], subs) return subs - elif isinstance(pattern, collections.abc.Sequence) and isinstance( - concrete, collections.abc.Sequence + elif isinstance(typ, collections.abc.Sequence) and isinstance( + subtyp, collections.abc.Sequence ): - if len(pattern) != len(concrete): - raise TypeError(f"Cannot unify {pattern} with {concrete}") - for p_item, c_item in zip(pattern, concrete): + if len(typ) != len(subtyp): + raise TypeError(f"Cannot unify {typ} with {subtyp}") + for p_item, c_item in zip(typ, subtyp): subs = unify(p_item, c_item, subs) return subs else: - if pattern != concrete: - raise TypeError(f"Cannot unify {pattern} with {concrete}") + if typ != subtyp: + raise TypeError(f"Cannot unify {typ} with {subtyp}") return subs -def freetypevars(typ: type | typing.TypeVar | types.GenericAlias | types.UnionType) -> set[typing.TypeVar]: +def freetypevars( + typ: type | typing.TypeVar | types.GenericAlias | types.UnionType, +) -> set[typing.TypeVar]: """ Return a set of free type variables in the given type expression. - + This function recursively traverses a type expression to find all TypeVar instances that appear within it. It handles both simple types and generic - type aliases with nested type arguments. TypeVars are considered "free" + type aliases with nested type arguments. TypeVars are considered "free" when they are not bound to a specific concrete type. - + Args: typ: The type expression to analyze. Can be a plain type (e.g., int), a TypeVar, or a generic type alias (e.g., List[T], Dict[K, V]). - + Returns: A set containing all TypeVar instances found in the type expression. Returns an empty set if no TypeVars are present. - + Examples: >>> T = typing.TypeVar('T') >>> K = typing.TypeVar('K') >>> V = typing.TypeVar('V') - + >>> # TypeVar returns itself >>> freetypevars(T) {~T} - + >>> # Generic type with one TypeVar >>> freetypevars(list[T]) {~T} - + >>> # Generic type with multiple TypeVars >>> sorted(freetypevars(dict[K, V]), key=lambda x: x.__name__) [~K, ~V] - + >>> # Nested generic types >>> sorted(freetypevars(list[dict[K, V]]), key=lambda x: x.__name__) [~K, ~V] - + >>> # Concrete types have no free TypeVars >>> freetypevars(int) set() - + >>> # Generic types with concrete arguments have no free TypeVars >>> freetypevars(list[int]) set() - + >>> # Mixed concrete and TypeVar arguments >>> freetypevars(dict[str, T]) {~T} """ if isinstance(typ, typing.TypeVar): return {typ} + elif isinstance(typ, list | tuple): + # Handle plain lists and tuples (not generic aliases) + return set.union(*(freetypevars(item) for item in typ)) if typ else set() elif typing.get_args(typ): + # Handle generic aliases return set.union(*(freetypevars(arg) for arg in typing.get_args(typ))) else: return set() def substitute( - typ: type | types.GenericAlias | types.UnionType, subs: collections.abc.Mapping[typing.TypeVar, type] + typ: type | types.GenericAlias | types.UnionType, + subs: collections.abc.Mapping[typing.TypeVar, type], ) -> type | types.GenericAlias | types.UnionType: """ Substitute type variables in a type expression with concrete types. @@ -220,13 +302,23 @@ def substitute( """ if isinstance(typ, typing.TypeVar): return subs.get(typ, typ) + elif isinstance(typ, list): + # Handle plain lists (e.g., in Callable's parameter list) + return [substitute(item, subs) for item in typ] + elif isinstance(typ, tuple): + # Handle plain tuples + return tuple(substitute(item, subs) for item in typ) elif typing.get_args(typ): origin = typing.get_origin(typ) assert origin is not None, "Type must have an origin" new_args = tuple(substitute(arg, subs) for arg in typing.get_args(typ)) # Handle Union types specially if origin is types.UnionType: - return typing.Union[new_args] + return typing.Union[new_args] # noqa + # Handle Callable types to preserve typing.Callable + elif origin is collections.abc.Callable: + # Use typing.Callable to get better repr + return typing.Callable[new_args[0], new_args[1]] return origin[new_args] else: return typ diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index d26b5c27..5eaf49b3 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -49,8 +49,22 @@ # Union types (if supported) (list[T] | dict[K, V], {T, K, V}), (T | int, {T}), + # Callable types + (typing.Callable[[T], V], {T, V}), + (typing.Callable[[int, T], T], {T}), + (typing.Callable[[], T], {T}), + (typing.Callable[[T, U], V], {T, U, V}), + (typing.Callable[[int], int], set()), + (typing.Callable[[T], list[T]], {T}), + (typing.Callable[[dict[K, V]], tuple[K, V]], {K, V}), + # Nested Callable + (typing.Callable[[T], typing.Callable[[U], V]], {T, U, V}), + (list[typing.Callable[[T], V]], {T, V}), + (dict[K, typing.Callable[[T], V]], {K, T, V}), + # ParamSpec and TypeVarTuple (if needed later) + # (typing.Callable[typing.ParamSpec("P"), T], {T}), # Would need to handle ParamSpec ], - ids=str + ids=str, ) def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): assert freetypevars(typ) == fvs @@ -91,22 +105,73 @@ def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): (list[tuple[int, T]], {T: str}, list[tuple[int, str]]), # Deeply nested substitution (list[dict[K, list[V]]], {K: str, V: int}, list[dict[str, list[int]]]), - (dict[tuple[K, V], list[T]], {K: int, V: str, T: float}, dict[tuple[int, str], list[float]]), + ( + dict[tuple[K, V], list[T]], + {K: int, V: str, T: float}, + dict[tuple[int, str], list[float]], + ), # Substituting with generic types (T, {T: list[int]}, list[int]), (list[T], {T: dict[str, int]}, list[dict[str, int]]), - (dict[K, V], {K: list[int], V: dict[str, float]}, dict[list[int], dict[str, float]]), + ( + dict[K, V], + {K: list[int], V: dict[str, float]}, + dict[list[int], dict[str, float]], + ), # Empty substitution (list[T], {}, list[T]), (dict[K, V], {}, dict[K, V]), # Union types (if supported) (T | int, {T: str}, str | int), - (list[T] | dict[K, V], {T: int, K: str, V: float}, list[int] | dict[str, float]), + ( + list[T] | dict[K, V], + {T: int, K: str, V: float}, + list[int] | dict[str, float], + ), # Irrelevant substitutions (TypeVars not in type) (list[T], {K: int, V: str}, list[T]), (int, {T: str, K: int}, int), + # Callable types + (typing.Callable[[T], V], {T: int, V: str}, typing.Callable[[int], str]), + (typing.Callable[[int, T], T], {T: str}, typing.Callable[[int, str], str]), + (typing.Callable[[], T], {T: float}, typing.Callable[[], float]), + ( + typing.Callable[[T, U], V], + {T: int, U: str, V: bool}, + typing.Callable[[int, str], bool], + ), + (typing.Callable[[int], int], {T: str}, typing.Callable[[int], int]), + (typing.Callable[[T], list[T]], {T: int}, typing.Callable[[int], list[int]]), + ( + typing.Callable[[dict[K, V]], tuple[K, V]], + {K: str, V: int}, + typing.Callable[[dict[str, int]], tuple[str, int]], + ), + # Nested Callable + ( + typing.Callable[[T], typing.Callable[[U], V]], + {T: int, U: str, V: bool}, + typing.Callable[[int], typing.Callable[[str], bool]], + ), + ( + list[typing.Callable[[T], V]], + {T: int, V: str}, + list[typing.Callable[[int], str]], + ), + ( + dict[K, typing.Callable[[T], V]], + {K: str, T: int, V: float}, + dict[str, typing.Callable[[int], float]], + ), + # Partial substitution with Callable + (typing.Callable[[T, U], V], {T: int}, typing.Callable[[int, U], V]), + ( + typing.Callable[[T], dict[K, V]], + {T: int, K: str}, + typing.Callable[[int], dict[str, V]], + ), ], - ids=str + ids=str, ) def test_substitute( typ: type, subs: typing.Mapping[typing.TypeVar, type], expected: type @@ -115,19 +180,19 @@ def test_substitute( @pytest.mark.parametrize( - "pattern,concrete,expected_subs", + "typ,subtyp,expected_subs", [ (T, int, {T: int}), (list[T], list[int], {T: int}), ], - ids=str + ids=str, ) def test_unify( - pattern: type, - concrete: type, + typ: type, + subtyp: type, expected_subs: typing.Mapping[typing.TypeVar, type], ): - assert unify(pattern, concrete, {}) == expected_subs + assert unify(typ, subtyp, {}) == expected_subs def test_infer_return_type(): From 13877f22a036f838d443e6ea61996d2dda1bc459 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Jul 2025 17:06:55 -0400 Subject: [PATCH 13/76] more test cases --- effectful/internals/unification.py | 38 +++++++- tests/test_internals_unification.py | 145 +++++++++++++++++++++++++++- 2 files changed, 173 insertions(+), 10 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 27abffd1..5ef38550 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -53,10 +53,22 @@ def infer_return_type( # Build substitution map subs: collections.abc.Mapping[typing.TypeVar, type] = {} for name in sig.parameters: - subs = unify(sig.parameters[name].annotation, bound_sig.arguments[name], subs) + typ = sig.parameters[name].annotation + if typing.get_origin(typ) is typing.Annotated: + typ = typing.get_args(typ)[0] + + subtyp = bound_sig.arguments[name] + if typing.get_origin(subtyp) is typing.Annotated: + subtyp = typing.get_args(subtyp)[0] + + subs = unify(typ, subtyp, subs) # Apply substitutions to return type - result_type = substitute(sig.return_annotation, subs) + return_annotation = sig.return_annotation + if typing.get_origin(return_annotation) is typing.Annotated: + return_annotation = typing.get_args(return_annotation)[0] + + result_type = substitute(return_annotation, subs) if freetypevars(result_type): raise TypeError( "Return type cannot have free type variables after substitution" @@ -156,14 +168,30 @@ def unify( >>> unify((T, V), (int, str), {}) {~T: , ~V: } """ - if isinstance(typ, typing.TypeVar): + if typing.get_origin(typ) is typing.Annotated: + # Handle Annotated types by extracting the base type + return unify(typing.get_args(typ)[0], subtyp, subs) + elif typing.get_origin(subtyp) is typing.Annotated: + # Handle Annotated types by extracting the base type + return unify(typ, typing.get_args(subtyp)[0], subs) + elif isinstance(typ, typing.TypeVar): if typ in subs and subs[typ] != subtyp: raise TypeError( f"Cannot unify {typ} with {subtyp} (already unified with {subs[typ]})" ) return {**subs, **{typ: subtyp}} elif typing.get_args(typ) and typing.get_args(subtyp): - if typing.get_origin(typ) != typing.get_origin(subtyp): + typ_origin = typing.get_origin(typ) + subtyp_origin = typing.get_origin(subtyp) + + # Handle Union types - both typing.Union and types.UnionType are compatible + if typ_origin in (typing.Union, types.UnionType) and subtyp_origin in ( + typing.Union, + types.UnionType, + ): + return unify(typing.get_args(typ), typing.get_args(subtyp), subs) + + if typ_origin != subtyp_origin: raise TypeError(f"Cannot unify {typ} with {subtyp}") return unify(typing.get_args(typ), typing.get_args(subtyp), subs) elif isinstance(typ, collections.abc.Mapping) and isinstance( @@ -183,7 +211,7 @@ def unify( subs = unify(p_item, c_item, subs) return subs else: - if typ != subtyp: + if not issubclass(typ, subtyp): raise TypeError(f"Cannot unify {typ} with {subtyp}") return subs diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 5eaf49b3..d085f1a9 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -180,19 +180,154 @@ def test_substitute( @pytest.mark.parametrize( - "typ,subtyp,expected_subs", + "typ,subtyp,initial_subs,expected_subs", [ - (T, int, {T: int}), - (list[T], list[int], {T: int}), + # Basic TypeVar unification + (T, int, {}, {T: int}), + (T, str, {}, {T: str}), + (T, list[int], {}, {T: list[int]}), + # With existing substitutions + (V, bool, {T: int}, {T: int, V: bool}), + (K, str, {T: int, V: bool}, {T: int, V: bool, K: str}), + # Generic type unification + (list[T], list[int], {}, {T: int}), + (dict[K, V], dict[str, int], {}, {K: str, V: int}), + (tuple[T, U], tuple[int, str], {}, {T: int, U: str}), + (set[T], set[float], {}, {T: float}), + # Same TypeVar used multiple times + (dict[T, T], dict[int, int], {}, {T: int}), + (tuple[T, T, T], tuple[str, str, str], {}, {T: str}), + # Nested generic unification + (list[dict[K, V]], list[dict[str, int]], {}, {K: str, V: int}), + (dict[K, list[V]], dict[int, list[str]], {}, {K: int, V: str}), + (list[tuple[T, U]], list[tuple[bool, float]], {}, {T: bool, U: float}), + # Deeply nested + (list[dict[K, list[V]]], list[dict[str, list[int]]], {}, {K: str, V: int}), + ( + dict[tuple[K, V], list[T]], + dict[tuple[int, str], list[bool]], + {}, + {K: int, V: str, T: bool}, + ), + # Mixed concrete and TypeVars + (dict[str, T], dict[str, int], {}, {T: int}), + (tuple[int, T, str], tuple[int, float, str], {}, {T: float}), + (list[tuple[int, T]], list[tuple[int, str]], {}, {T: str}), + # Exact type matching (no TypeVars) + (int, int, {}, {}), + (str, str, {}, {}), + (list[int], list[int], {}, {}), + (dict[str, int], dict[str, int], {}, {}), + # Callable type unification + (typing.Callable[[T], V], typing.Callable[[int], str], {}, {T: int, V: str}), + ( + typing.Callable[[T, U], V], + typing.Callable[[int, str], bool], + {}, + {T: int, U: str, V: bool}, + ), + (typing.Callable[[], T], typing.Callable[[], float], {}, {T: float}), + ( + typing.Callable[[T], list[T]], + typing.Callable[[int], list[int]], + {}, + {T: int}, + ), + # Nested Callable + ( + typing.Callable[[T], typing.Callable[[U], V]], + typing.Callable[[int], typing.Callable[[str], bool]], + {}, + {T: int, U: str, V: bool}, + ), + # Union types - basic case where both sides are unions + # Note: Complex union unification like (T | int, str | int) -> {T: str} + # would require more sophisticated logic + # Sequence unification (tuples as sequences) + ((T, V), (int, str), {}, {T: int, V: str}), + ([T, V], [int, str], {}, {T: int, V: str}), + # Complex combinations + ( + dict[K, typing.Callable[[T], V]], + dict[str, typing.Callable[[int], bool]], + {}, + {K: str, T: int, V: bool}, + ), ], ids=str, ) -def test_unify( +def test_unify_success( typ: type, subtyp: type, + initial_subs: typing.Mapping[typing.TypeVar, type], expected_subs: typing.Mapping[typing.TypeVar, type], ): - assert unify(typ, subtyp, {}) == expected_subs + assert unify(typ, subtyp, initial_subs) == expected_subs + + +@pytest.mark.parametrize( + "typ,subtyp,initial_subs,error_pattern", + [ + # Incompatible types + ( + list[T], + dict[str, int], + {}, + "Cannot unify list\\[~T\\] with dict\\[str, int\\]", + ), + (int, str, {}, "Cannot unify with "), + (list[int], list[str], {}, "Cannot unify with "), + # Conflicting TypeVar bindings + ( + T, + str, + {T: int}, + "Cannot unify ~T with \\(already unified with \\)", + ), + ( + list[T], + list[str], + {T: int}, + "Cannot unify ~T with \\(already unified with \\)", + ), + # Mismatched generic types + (list[T], set[int], {}, "Cannot unify list\\[~T\\] with set\\[int\\]"), + (dict[K, V], list[int], {}, "Cannot unify dict\\[~K, ~V\\] with list\\[int\\]"), + # Same TypeVar with different values + ( + dict[T, T], + dict[int, str], + {}, + "Cannot unify ~T with \\(already unified with \\)", + ), + ( + tuple[T, T], + tuple[int, str], + {}, + "Cannot unify ~T with \\(already unified with \\)", + ), + # Mismatched arities + (tuple[T, U], tuple[int, str, bool], {}, "Cannot unify"), + ( + typing.Callable[[T], V], + typing.Callable[[int, str], bool], + {}, + "Cannot unify", + ), + # Sequence length mismatch + ((T, V), (int,), {}, "Cannot unify"), + ([T, V], [int, str, bool], {}, "Cannot unify"), + ], + ids=str, +) +def test_unify_failure( + typ: type, + subtyp: type, + initial_subs: typing.Mapping[typing.TypeVar, type], + error_pattern: str, +): + with pytest.raises(TypeError, match=error_pattern): + unify(typ, subtyp, initial_subs) def test_infer_return_type(): From 0dbd307eb06caae9fd15fdbd1547410a4c669031 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Jul 2025 17:07:33 -0400 Subject: [PATCH 14/76] case --- effectful/internals/unification.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 5ef38550..c5ca72a5 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -330,12 +330,9 @@ def substitute( """ if isinstance(typ, typing.TypeVar): return subs.get(typ, typ) - elif isinstance(typ, list): - # Handle plain lists (e.g., in Callable's parameter list) - return [substitute(item, subs) for item in typ] - elif isinstance(typ, tuple): - # Handle plain tuples - return tuple(substitute(item, subs) for item in typ) + elif isinstance(typ, list | tuple): + # Handle plain lists/sequences (e.g., in Callable's parameter list) + return type(typ)(substitute(item, subs) for item in typ) elif typing.get_args(typ): origin = typing.get_origin(typ) assert origin is not None, "Type must have an origin" From e675bf601c4dda54dc8b78a610ad9f207a3e5f3c Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 3 Jul 2025 17:48:24 -0400 Subject: [PATCH 15/76] docstring and test cases for infer_return_type --- effectful/internals/unification.py | 88 ++++++- tests/test_internals_unification.py | 378 ++++++++++++++++++++++++---- 2 files changed, 411 insertions(+), 55 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index c5ca72a5..640edb18 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -9,6 +9,74 @@ def infer_return_type( ) -> type | types.GenericAlias | types.UnionType: """ Infer the return type of a function based on its signature and argument types. + + This function takes a BoundArguments object (created by binding concrete argument + types to a function signature) and infers what the return type should be by: + 1. Finding all TypeVars in the function's parameter and return annotations + 2. Unifying the parameter type annotations with the concrete argument types + 3. Applying the resulting TypeVar substitutions to the return type annotation + + The function ensures that all type variables in the return type can be inferred + from the parameter types (no unbound type variables in the return). + + Args: + bound_sig: A BoundArguments object obtained by calling + inspect.signature(func).bind(*arg_types, **kwarg_types) + where arg_types and kwarg_types are concrete types + + Returns: + The inferred return type with all TypeVars substituted with concrete types + + Raises: + TypeError: If the function lacks required type annotations, has unbound + type variables in the return type, or if unification fails + NotImplementedError: If the function uses variadic parameters (*args, **kwargs), + collection types as parameters, or parameters with + free type variables + + Examples: + >>> import inspect + >>> import typing + >>> T = typing.TypeVar('T') + >>> K = typing.TypeVar('K') + >>> V = typing.TypeVar('V') + + >>> # Simple generic function + >>> def identity(x: T) -> T: ... + >>> sig = inspect.signature(identity) + >>> bound = sig.bind(int) + >>> infer_return_type(bound) + + + >>> # Function with multiple TypeVars + >>> def make_dict(key: K, value: V) -> dict[K, V]: ... + >>> sig = inspect.signature(make_dict) + >>> bound = sig.bind(str, int) + >>> infer_return_type(bound) + dict[str, int] + + >>> # Function with nested generics + >>> def wrap_in_list(x: T) -> list[T]: ... + >>> sig = inspect.signature(wrap_in_list) + >>> bound = sig.bind(bool) + >>> infer_return_type(bound) + list[bool] + + >>> # Function with no TypeVars + >>> def get_int() -> int: ... + >>> sig = inspect.signature(get_int) + >>> bound = sig.bind() + >>> infer_return_type(bound) + + + >>> # Error: unbound type variable in return + >>> def bad_func(x: T) -> tuple[T, K]: ... # K not in parameters + >>> sig = inspect.signature(bad_func) + >>> bound = sig.bind(int) + >>> infer_return_type(bound) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + TypeError: unbound type variables in return type """ bound_sig.apply_defaults() sig: inspect.Signature = bound_sig.signature @@ -21,11 +89,15 @@ def infer_return_type( raise TypeError("All parameters must have type annotations") result_fvs: set[typing.TypeVar] = freetypevars(sig.return_annotation) - pattern_fvs: set[typing.TypeVar] = set.union( - *(freetypevars(p.annotation) for p in sig.parameters.values()), + pattern_fvs: set[typing.TypeVar] = ( + set.union(*(freetypevars(p.annotation) for p in sig.parameters.values())) + if sig.parameters + else set() ) - concrete_fvs: set[typing.TypeVar] = set.union( - *(freetypevars(arg) for arg in bound_sig.arguments.values()), + concrete_fvs: set[typing.TypeVar] = ( + set.union(*(freetypevars(arg) for arg in bound_sig.arguments.values())) + if bound_sig.arguments + else set() ) if (result_fvs | pattern_fvs) & concrete_fvs: raise TypeError( @@ -161,7 +233,7 @@ def unify( TypeError: Cannot unify ~T with (already unified with ) >>> # Callable type unification - >>> unify(typing.Callable[[T], V], typing.Callable[[int], str], {}) + >>> unify(collections.abc.Callable[[T], V], collections.abc.Callable[[int], str], {}) {~T: , ~V: } >>> # Sequence unification (tuples as sequences) @@ -270,6 +342,8 @@ def freetypevars( """ if isinstance(typ, typing.TypeVar): return {typ} + elif typing.get_origin(typ) is typing.Annotated: + return freetypevars(typing.get_args(typ)[0]) elif isinstance(typ, list | tuple): # Handle plain lists and tuples (not generic aliases) return set.union(*(freetypevars(item) for item in typ)) if typ else set() @@ -340,10 +414,6 @@ def substitute( # Handle Union types specially if origin is types.UnionType: return typing.Union[new_args] # noqa - # Handle Callable types to preserve typing.Callable - elif origin is collections.abc.Callable: - # Use typing.Callable to get better repr - return typing.Callable[new_args[0], new_args[1]] return origin[new_args] else: return typ diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index d085f1a9..2ea4a8be 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -1,9 +1,12 @@ +import collections.abc +import inspect import typing import pytest from effectful.internals.unification import ( freetypevars, + infer_return_type, substitute, unify, ) @@ -50,19 +53,19 @@ (list[T] | dict[K, V], {T, K, V}), (T | int, {T}), # Callable types - (typing.Callable[[T], V], {T, V}), - (typing.Callable[[int, T], T], {T}), - (typing.Callable[[], T], {T}), - (typing.Callable[[T, U], V], {T, U, V}), - (typing.Callable[[int], int], set()), - (typing.Callable[[T], list[T]], {T}), - (typing.Callable[[dict[K, V]], tuple[K, V]], {K, V}), + (collections.abc.Callable[[T], V], {T, V}), + (collections.abc.Callable[[int, T], T], {T}), + (collections.abc.Callable[[], T], {T}), + (collections.abc.Callable[[T, U], V], {T, U, V}), + (collections.abc.Callable[[int], int], set()), + (collections.abc.Callable[[T], list[T]], {T}), + (collections.abc.Callable[[dict[K, V]], tuple[K, V]], {K, V}), # Nested Callable - (typing.Callable[[T], typing.Callable[[U], V]], {T, U, V}), - (list[typing.Callable[[T], V]], {T, V}), - (dict[K, typing.Callable[[T], V]], {K, T, V}), + (collections.abc.Callable[[T], collections.abc.Callable[[U], V]], {T, U, V}), + (list[collections.abc.Callable[[T], V]], {T, V}), + (dict[K, collections.abc.Callable[[T], V]], {K, T, V}), # ParamSpec and TypeVarTuple (if needed later) - # (typing.Callable[typing.ParamSpec("P"), T], {T}), # Would need to handle ParamSpec + # (collections.abc.Callable[typing.ParamSpec("P"), T], {T}), # Would need to handle ParamSpec ], ids=str, ) @@ -132,43 +135,67 @@ def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): (list[T], {K: int, V: str}, list[T]), (int, {T: str, K: int}, int), # Callable types - (typing.Callable[[T], V], {T: int, V: str}, typing.Callable[[int], str]), - (typing.Callable[[int, T], T], {T: str}, typing.Callable[[int, str], str]), - (typing.Callable[[], T], {T: float}, typing.Callable[[], float]), ( - typing.Callable[[T, U], V], + collections.abc.Callable[[T], V], + {T: int, V: str}, + collections.abc.Callable[[int], str], + ), + ( + collections.abc.Callable[[int, T], T], + {T: str}, + collections.abc.Callable[[int, str], str], + ), + ( + collections.abc.Callable[[], T], + {T: float}, + collections.abc.Callable[[], float], + ), + ( + collections.abc.Callable[[T, U], V], {T: int, U: str, V: bool}, - typing.Callable[[int, str], bool], + collections.abc.Callable[[int, str], bool], ), - (typing.Callable[[int], int], {T: str}, typing.Callable[[int], int]), - (typing.Callable[[T], list[T]], {T: int}, typing.Callable[[int], list[int]]), ( - typing.Callable[[dict[K, V]], tuple[K, V]], + collections.abc.Callable[[int], int], + {T: str}, + collections.abc.Callable[[int], int], + ), + ( + collections.abc.Callable[[T], list[T]], + {T: int}, + collections.abc.Callable[[int], list[int]], + ), + ( + collections.abc.Callable[[dict[K, V]], tuple[K, V]], {K: str, V: int}, - typing.Callable[[dict[str, int]], tuple[str, int]], + collections.abc.Callable[[dict[str, int]], tuple[str, int]], ), # Nested Callable ( - typing.Callable[[T], typing.Callable[[U], V]], + collections.abc.Callable[[T], collections.abc.Callable[[U], V]], {T: int, U: str, V: bool}, - typing.Callable[[int], typing.Callable[[str], bool]], + collections.abc.Callable[[int], collections.abc.Callable[[str], bool]], ), ( - list[typing.Callable[[T], V]], + list[collections.abc.Callable[[T], V]], {T: int, V: str}, - list[typing.Callable[[int], str]], + list[collections.abc.Callable[[int], str]], ), ( - dict[K, typing.Callable[[T], V]], + dict[K, collections.abc.Callable[[T], V]], {K: str, T: int, V: float}, - dict[str, typing.Callable[[int], float]], + dict[str, collections.abc.Callable[[int], float]], ), # Partial substitution with Callable - (typing.Callable[[T, U], V], {T: int}, typing.Callable[[int, U], V]), ( - typing.Callable[[T], dict[K, V]], + collections.abc.Callable[[T, U], V], + {T: int}, + collections.abc.Callable[[int, U], V], + ), + ( + collections.abc.Callable[[T], dict[K, V]], {T: int, K: str}, - typing.Callable[[int], dict[str, V]], + collections.abc.Callable[[int], dict[str, V]], ), ], ids=str, @@ -219,37 +246,56 @@ def test_substitute( (list[int], list[int], {}, {}), (dict[str, int], dict[str, int], {}, {}), # Callable type unification - (typing.Callable[[T], V], typing.Callable[[int], str], {}, {T: int, V: str}), ( - typing.Callable[[T, U], V], - typing.Callable[[int, str], bool], + collections.abc.Callable[[T], V], + collections.abc.Callable[[int], str], + {}, + {T: int, V: str}, + ), + ( + collections.abc.Callable[[T, U], V], + collections.abc.Callable[[int, str], bool], {}, {T: int, U: str, V: bool}, ), - (typing.Callable[[], T], typing.Callable[[], float], {}, {T: float}), ( - typing.Callable[[T], list[T]], - typing.Callable[[int], list[int]], + collections.abc.Callable[[], T], + collections.abc.Callable[[], float], + {}, + {T: float}, + ), + ( + collections.abc.Callable[[T], list[T]], + collections.abc.Callable[[int], list[int]], {}, {T: int}, ), # Nested Callable ( - typing.Callable[[T], typing.Callable[[U], V]], - typing.Callable[[int], typing.Callable[[str], bool]], + collections.abc.Callable[[T], collections.abc.Callable[[U], V]], + collections.abc.Callable[[int], collections.abc.Callable[[str], bool]], {}, {T: int, U: str, V: bool}, ), - # Union types - basic case where both sides are unions - # Note: Complex union unification like (T | int, str | int) -> {T: str} - # would require more sophisticated logic + # Union types - basic element-wise unification (current implementation) + # Note: Current unify treats union args as sequences, not true union logic + ( + T | V, + int | str, + {}, + {T: int, V: str}, + ), # Element-wise unification of TypeVars + (T | V, int | str, {}, {T: int, V: str}), # typing.Union syntax + # Simple union compatibility - TypeVar gets unified with itself + (T | int, T | int, {}, {T: T}), # Identical unions - T unifies with T + (T | int, T | int, {}, {T: T}), # Identical typing.Union # Sequence unification (tuples as sequences) ((T, V), (int, str), {}, {T: int, V: str}), ([T, V], [int, str], {}, {T: int, V: str}), # Complex combinations ( - dict[K, typing.Callable[[T], V]], - dict[str, typing.Callable[[int], bool]], + dict[K, collections.abc.Callable[[T], V]], + dict[str, collections.abc.Callable[[int], bool]], {}, {K: str, T: int, V: bool}, ), @@ -309,14 +355,35 @@ def test_unify_success( # Mismatched arities (tuple[T, U], tuple[int, str, bool], {}, "Cannot unify"), ( - typing.Callable[[T], V], - typing.Callable[[int, str], bool], + collections.abc.Callable[[T], V], + collections.abc.Callable[[int, str], bool], {}, "Cannot unify", ), # Sequence length mismatch ((T, V), (int,), {}, "Cannot unify"), ([T, V], [int, str, bool], {}, "Cannot unify"), + # Union type failures - element-wise unification failures + ( + T | V, + int | str, + {T: float}, + "Cannot unify ~T with ", + ), # TypeVar conflict + ( + T | int, + V | str, + {}, + "Cannot unify with ", + ), # Concrete type mismatch + ( + T | int, + V | str, + {}, + "Cannot unify with ", + ), # typing.Union mismatch + # Union with different arities + (T | V, int | str | bool, {}, "Cannot unify"), # Different union sizes ], ids=str, ) @@ -330,5 +397,224 @@ def test_unify_failure( unify(typ, subtyp, initial_subs) -def test_infer_return_type(): - pass # TODO fill this in +# Test functions with various type patterns +def identity(x: T) -> T: + return x + + +def make_pair(x: T, y: V) -> tuple[T, V]: + return (x, y) + + +def wrap_in_list(x: T) -> list[T]: + return [x] + + +def get_first(items: list[T]) -> T: + return items[0] + + +def getitem_mapping(mapping: collections.abc.Mapping[K, V], key: K) -> V: + return mapping[key] + + +def dict_values(d: dict[K, V]) -> list[V]: + return list(d.values()) + + +def process_callable(func: collections.abc.Callable[[T], V], arg: T) -> V: + return func(arg) + + +def chain_callables( + f: collections.abc.Callable[[T], U], g: collections.abc.Callable[[U], V] +) -> collections.abc.Callable[[T], V]: + def result(x: T) -> V: + return g(f(x)) + + return result + + +def constant_func() -> int: + return 42 + + +def multi_generic(a: T, b: list[T], c: dict[K, V]) -> tuple[T, K, V]: + return (a, next(iter(c.keys())), next(iter(c.values()))) + + +def same_type_twice(x: T, y: T) -> T: + return x if len(str(x)) > len(str(y)) else y + + +def nested_generic(x: T) -> dict[str, list[T]]: + return {"items": [x]} + + +@pytest.mark.parametrize( + "func,args,kwargs,expected_return_type", + [ + # Simple generic functions + (identity, (int,), {}, int), + (identity, (str,), {}, str), + (identity, (list[int],), {}, list[int]), + # Multiple TypeVars + (make_pair, (int, str), {}, tuple[int, str]), + (make_pair, (bool, list[float]), {}, tuple[bool, list[float]]), + # Generic collections + (wrap_in_list, (int,), {}, list[int]), + (wrap_in_list, (dict[str, bool],), {}, list[dict[str, bool]]), + (get_first, (list[str],), {}, str), + (get_first, (list[tuple[int, float]],), {}, tuple[int, float]), + (getitem_mapping, (collections.abc.Mapping[str, int], str), {}, int), + ( + getitem_mapping, + (collections.abc.Mapping[bool, list[str]], bool), + {}, + list[str], + ), + # Dict operations + (dict_values, (dict[str, int],), {}, list[int]), + (dict_values, (dict[bool, list[str]],), {}, list[list[str]]), + # Callable types + (process_callable, (collections.abc.Callable[[int], str], int), {}, str), + ( + process_callable, + (collections.abc.Callable[[list[int]], bool], list[int]), + {}, + bool, + ), + # Complex callable return + ( + chain_callables, + ( + collections.abc.Callable[[int], str], + collections.abc.Callable[[str], bool], + ), + {}, + collections.abc.Callable[[int], bool], + ), + # No generics + (constant_func, (), {}, int), + # Mixed generics + (multi_generic, (int, list[int], dict[str, bool]), {}, tuple[int, str, bool]), + ( + multi_generic, + (float, list[float], dict[bool, list[str]]), + {}, + tuple[float, bool, list[str]], + ), + # Same TypeVar used multiple times + (same_type_twice, (int, int), {}, int), + (same_type_twice, (str, str), {}, str), + # Nested generics + (nested_generic, (int,), {}, dict[str, list[int]]), + ( + nested_generic, + (collections.abc.Callable[[str], bool],), + {}, + dict[str, list[collections.abc.Callable[[str], bool]]], + ), + # Keyword arguments + (make_pair, (), {"x": int, "y": str}, tuple[int, str]), + ( + multi_generic, + (), + {"a": bool, "b": list[bool], "c": dict[int, str]}, + tuple[bool, int, str], + ), + ], + ids=str, +) +def test_infer_return_type_success( + func: collections.abc.Callable, + args: tuple, + kwargs: dict, + expected_return_type: type, +): + sig = inspect.signature(func) + bound = sig.bind(*args, **kwargs) + result = infer_return_type(bound) + assert result == expected_return_type + + +# Error cases +def unbound_typevar_func(x: T) -> tuple[T, V]: # V not in parameters + return (x, "error") # type: ignore + + +def no_return_annotation(x: T): # No return annotation + return x + + +def no_param_annotation(x) -> T: # No parameter annotation + return x # type: ignore + + +def variadic_args_func(*args: T) -> T: # Variadic args not supported + return args[0] + + +def variadic_kwargs_func(**kwargs: T) -> T: # Variadic kwargs not supported + return next(iter(kwargs.values())) + + +@pytest.mark.parametrize( + "func,args,kwargs,error_type,error_pattern", + [ + # Unbound type variable in return + ( + unbound_typevar_func, + (int,), + {}, + TypeError, + "unbound type variables in return type", + ), + # Missing annotations + ( + no_return_annotation, + (int,), + {}, + TypeError, + "Function must have a return type annotation", + ), + ( + no_param_annotation, + (int,), + {}, + TypeError, + "All parameters must have type annotations", + ), + # Type mismatch - trying to unify incompatible types + (same_type_twice, (int, str), {}, TypeError, "Cannot unify"), + ], + ids=str, +) +def test_infer_return_type_failure( + func: collections.abc.Callable, + args: tuple, + kwargs: dict, + error_type: type[Exception], + error_pattern: str, +): + sig = inspect.signature(func) + bound = sig.bind(*args, **kwargs) + with pytest.raises(error_type, match=error_pattern): + infer_return_type(bound) + + +# Variadic functions - not implemented yet, marked as expected failures +@pytest.mark.xfail(reason="Variadic args not implemented") +def test_infer_return_type_variadic_args(): + sig = inspect.signature(variadic_args_func) + bound = sig.bind(int) + result = infer_return_type(bound) + assert result == int + + +@pytest.mark.xfail(reason="Variadic kwargs not implemented") +def test_infer_return_type_variadic_kwargs(): + sig = inspect.signature(variadic_kwargs_func) + bound = sig.bind(x=int) + result = infer_return_type(bound) + assert result == int From 90dd9fb05dcb0f560347fe9b753c68374eec4a4c Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 10:34:19 -0400 Subject: [PATCH 16/76] canonicalize --- effectful/internals/unification.py | 74 ++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 14 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 640edb18..3679689f 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -125,22 +125,14 @@ def infer_return_type( # Build substitution map subs: collections.abc.Mapping[typing.TypeVar, type] = {} for name in sig.parameters: - typ = sig.parameters[name].annotation - if typing.get_origin(typ) is typing.Annotated: - typ = typing.get_args(typ)[0] - - subtyp = bound_sig.arguments[name] - if typing.get_origin(subtyp) is typing.Annotated: - subtyp = typing.get_args(subtyp)[0] - - subs = unify(typ, subtyp, subs) + subs = unify( + canonicalize(sig.parameters[name].annotation), + canonicalize(bound_sig.arguments[name]), + subs, + ) # Apply substitutions to return type - return_annotation = sig.return_annotation - if typing.get_origin(return_annotation) is typing.Annotated: - return_annotation = typing.get_args(return_annotation)[0] - - result_type = substitute(return_annotation, subs) + result_type = substitute(canonicalize(sig.return_annotation), subs) if freetypevars(result_type): raise TypeError( "Return type cannot have free type variables after substitution" @@ -288,6 +280,60 @@ def unify( return subs +def canonicalize( + typ: type | typing.TypeVar | types.GenericAlias | types.UnionType, +) -> type: + """ + Return a canonical form of the given type expression. + + This function normalizes the type by removing Annotated wrappers and + ensuring that generic types are represented in their canonical form. + It does not modify TypeVars or Union types, but ensures that generic + aliases are returned in a consistent format. + + Args: + typ: The type expression to canonicalize. + + Returns: + A canonicalized version of the input type expression. + + Examples: + >>> T = typing.TypeVar('T') + >>> canonicalize(typing.List[T]) + list[~T] + >>> canonicalize(typing.Annotated[int, "example"]) + + """ + if typing.get_origin(typ) is typing.Annotated: + return canonicalize(typing.get_args(typ)[0]) + elif typing.get_origin(typ) in {typing.Union, types.UnionType}: + t = canonicalize(typing.get_args(typ)[0]) + for arg in typing.get_args(typ)[1:]: + t = t | canonicalize(arg) + return t + elif isinstance(typ, typing._GenericAlias | types.GenericAlias): # type: ignore + # Handle generic types + origin = canonicalize(typing.get_origin(typ)) + assert origin is not None, "Type must have an origin" + return origin[tuple(canonicalize(a) for a in typing.get_args(typ))] + elif isinstance(typ, collections.abc.Sequence): + return tuple(canonicalize(item) for item in typ) + elif typ is typing.Callable: + return collections.abc.Callable + elif typ is typing.Any: + return object + elif typ is list: + return list + elif typ is dict: + return dict + elif typ is set: + return set + elif typ is tuple: + return tuple + else: + return typ + + def freetypevars( typ: type | typing.TypeVar | types.GenericAlias | types.UnionType, ) -> set[typing.TypeVar]: From c95ea014f216db36ff01f9ec0677cec8b41a49a4 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 11:16:22 -0400 Subject: [PATCH 17/76] fixes --- effectful/internals/unification.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 3679689f..14a73b47 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -78,16 +78,12 @@ def infer_return_type( ... TypeError: unbound type variables in return type """ - bound_sig.apply_defaults() sig: inspect.Signature = bound_sig.signature # validate that the function has a signature with well-formed type annotations if sig.return_annotation is inspect.Signature.empty: raise TypeError("Function must have a return type annotation") - if any(p.annotation is inspect.Signature.empty for p in sig.parameters.values()): - raise TypeError("All parameters must have type annotations") - result_fvs: set[typing.TypeVar] = freetypevars(sig.return_annotation) pattern_fvs: set[typing.TypeVar] = ( set.union(*(freetypevars(p.annotation) for p in sig.parameters.values())) @@ -133,7 +129,9 @@ def infer_return_type( # Apply substitutions to return type result_type = substitute(canonicalize(sig.return_annotation), subs) - if freetypevars(result_type): + if freetypevars(result_type) and not issubclass( + typing.get_origin(result_type), collections.abc.Callable + ): raise TypeError( "Return type cannot have free type variables after substitution" ) @@ -243,6 +241,7 @@ def unify( raise TypeError( f"Cannot unify {typ} with {subtyp} (already unified with {subs[typ]})" ) + return {**subs, **{typ: subtyp}} elif typing.get_args(typ) and typing.get_args(subtyp): typ_origin = typing.get_origin(typ) @@ -275,7 +274,9 @@ def unify( subs = unify(p_item, c_item, subs) return subs else: - if not issubclass(typ, subtyp): + subtyp = typing.get_origin(subtyp) or subtyp + typ = typing.get_origin(typ) or typ + if not issubclass(subtyp, typ): raise TypeError(f"Cannot unify {typ} with {subtyp}") return subs @@ -311,6 +312,9 @@ def canonicalize( for arg in typing.get_args(typ)[1:]: t = t | canonicalize(arg) return t + elif isinstance(typ, typing.TypeVar): + # TypeVars are already canonical + return typ elif isinstance(typ, typing._GenericAlias | types.GenericAlias): # type: ignore # Handle generic types origin = canonicalize(typing.get_origin(typ)) @@ -318,6 +322,8 @@ def canonicalize( return origin[tuple(canonicalize(a) for a in typing.get_args(typ))] elif isinstance(typ, collections.abc.Sequence): return tuple(canonicalize(item) for item in typ) + elif typ is inspect.Parameter.empty: + return canonicalize(typing.Any) elif typ is typing.Callable: return collections.abc.Callable elif typ is typing.Any: @@ -330,6 +336,8 @@ def canonicalize( return set elif typ is tuple: return tuple + elif not isinstance(typ, type) and typing.get_origin(typ) is None: + return type(typ) else: return typ From 4a1461fe5499ea6bef24ed6deda32c89c4205d95 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 11:26:06 -0400 Subject: [PATCH 18/76] fix error type --- effectful/internals/unification.py | 6 +++--- tests/test_internals_unification.py | 14 +++----------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 14a73b47..02c8bdf1 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -108,13 +108,13 @@ def infer_return_type( inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD, }: - raise NotImplementedError(f"Parameter '{name}' cannot be variadic") + raise TypeError(f"Parameter '{name}' cannot be variadic") if isinstance(bound_sig.arguments[name], collections.abc.Collection): - raise NotImplementedError(f"Parameter '{name}' cannot be a collection type") + raise TypeError(f"Parameter '{name}' cannot be a collection type") if freetypevars(bound_sig.arguments[name]): - raise NotImplementedError( + raise TypeError( f"Parameter '{name}' cannot have free type variables" ) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 2ea4a8be..a4ed670d 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -560,33 +560,27 @@ def variadic_kwargs_func(**kwargs: T) -> T: # Variadic kwargs not supported @pytest.mark.parametrize( - "func,args,kwargs,error_type,error_pattern", + "func,args,kwargs", [ # Unbound type variable in return ( unbound_typevar_func, (int,), {}, - TypeError, - "unbound type variables in return type", ), # Missing annotations ( no_return_annotation, (int,), {}, - TypeError, - "Function must have a return type annotation", ), ( no_param_annotation, (int,), {}, - TypeError, - "All parameters must have type annotations", ), # Type mismatch - trying to unify incompatible types - (same_type_twice, (int, str), {}, TypeError, "Cannot unify"), + (same_type_twice, (int, str), {}), ], ids=str, ) @@ -594,12 +588,10 @@ def test_infer_return_type_failure( func: collections.abc.Callable, args: tuple, kwargs: dict, - error_type: type[Exception], - error_pattern: str, ): sig = inspect.signature(func) bound = sig.bind(*args, **kwargs) - with pytest.raises(error_type, match=error_pattern): + with pytest.raises(TypeError): infer_return_type(bound) From a27d811c7814992dcf5c4b633bedd345d20b6c1c Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 11:49:23 -0400 Subject: [PATCH 19/76] add _nested_type for sequences --- effectful/internals/unification.py | 39 +++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 02c8bdf1..00f52723 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -110,9 +110,6 @@ def infer_return_type( }: raise TypeError(f"Parameter '{name}' cannot be variadic") - if isinstance(bound_sig.arguments[name], collections.abc.Collection): - raise TypeError(f"Parameter '{name}' cannot be a collection type") - if freetypevars(bound_sig.arguments[name]): raise TypeError( f"Parameter '{name}' cannot have free type variables" @@ -320,8 +317,6 @@ def canonicalize( origin = canonicalize(typing.get_origin(typ)) assert origin is not None, "Type must have an origin" return origin[tuple(canonicalize(a) for a in typing.get_args(typ))] - elif isinstance(typ, collections.abc.Sequence): - return tuple(canonicalize(item) for item in typ) elif typ is inspect.Parameter.empty: return canonicalize(typing.Any) elif typ is typing.Callable: @@ -329,19 +324,41 @@ def canonicalize( elif typ is typing.Any: return object elif typ is list: - return list + return collections.abc.Sequence elif typ is dict: - return dict + return collections.abc.Mapping elif typ is set: - return set - elif typ is tuple: - return tuple + return collections.abc.Set elif not isinstance(typ, type) and typing.get_origin(typ) is None: - return type(typ) + return canonicalize(_nested_type(typ)) else: return typ +def _nested_type(value) -> type: + from effectful.ops.types import Interpretation, Operation + + if isinstance(value, Interpretation): + return Interpretation + elif isinstance(value, Operation): + return Operation + elif isinstance(value, tuple): + return tuple[tuple(_nested_type(item) for item in value)] + elif isinstance(value, collections.abc.Mapping): + k, v = next(iter(value.items())) + return collections.abc.Mapping[_nested_type(k), _nested_type(v)] + elif isinstance(value, collections.abc.Sequence) and len(value) > 0 and not isinstance(value, str | bytes): + return collections.abc.Sequence[_nested_type(next(iter(value)))] + elif isinstance(value, collections.abc.Set): + return collections.abc.Set[_nested_type(next(iter(value)))] + elif isinstance(value, collections.abc.Callable): + return type(value) + elif not isinstance(value, type) and typing.get_origin(value) is None: + return type(value) + else: + return value + + def freetypevars( typ: type | typing.TypeVar | types.GenericAlias | types.UnionType, ) -> set[typing.TypeVar]: From c5c09a32ce1e749a639909c10b95578f9417fa86 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 12:13:24 -0400 Subject: [PATCH 20/76] make _nested_type singledispatch extensible --- effectful/internals/unification.py | 68 +++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 00f52723..f8bb410b 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -1,4 +1,5 @@ import collections.abc +import functools import inspect import types import typing @@ -335,28 +336,65 @@ def canonicalize( return typ +@functools.singledispatch def _nested_type(value) -> type: - from effectful.ops.types import Interpretation, Operation + if not isinstance(value, type) and typing.get_origin(value) is None: + return type(value) + else: + return value + + +@_nested_type.register +def _(value: type | types.UnionType | types.GenericAlias | types.EllipsisType | types.NoneType) -> type: + return value + - if isinstance(value, Interpretation): +@_nested_type.register +def _(value: typing.TypeVar) -> type: + raise TypeError(f"TypeVars should not appear in values, but got {value}") + + +@_nested_type.register +def _(value: collections.abc.Callable) -> type: + return type(value) + + +@_nested_type.register +def _(value: collections.abc.Mapping) -> type: + from effectful.ops.types import Interpretation + + if isinstance(value, Interpretation): # type: ignore return Interpretation - elif isinstance(value, Operation): - return Operation - elif isinstance(value, tuple): - return tuple[tuple(_nested_type(item) for item in value)] - elif isinstance(value, collections.abc.Mapping): + elif len(value) == 0: + return type(value) + else: k, v = next(iter(value.items())) return collections.abc.Mapping[_nested_type(k), _nested_type(v)] - elif isinstance(value, collections.abc.Sequence) and len(value) > 0 and not isinstance(value, str | bytes): - return collections.abc.Sequence[_nested_type(next(iter(value)))] - elif isinstance(value, collections.abc.Set): - return collections.abc.Set[_nested_type(next(iter(value)))] - elif isinstance(value, collections.abc.Callable): + + +@_nested_type.register +def _(value: collections.abc.Set) -> type: + if len(value) == 0: return type(value) - elif not isinstance(value, type) and typing.get_origin(value) is None: + return collections.abc.Set[_nested_type(next(iter(value)))] + + +@_nested_type.register +def _(value: collections.abc.Sequence) -> type: + if len(value) == 0: return type(value) - else: - return value + return collections.abc.Sequence[_nested_type(next(iter(value)))] + + +@_nested_type.register +def _(value: tuple) -> type: + return tuple[tuple(_nested_type(item) for item in value)] + + +@_nested_type.register +def _(value: str | bytes) -> type: + # Handle str and bytes as their own types, not collections.abc.Sequence + return type(value) def freetypevars( From ce1243367e2bcb6a3186e2d6871582dde9cc25ac Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 12:16:16 -0400 Subject: [PATCH 21/76] term case --- effectful/internals/unification.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index f8bb410b..e61f70c9 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -338,7 +338,11 @@ def canonicalize( @functools.singledispatch def _nested_type(value) -> type: - if not isinstance(value, type) and typing.get_origin(value) is None: + from effectful.ops.types import Term + + if isinstance(value, Term): + raise TypeError(f"Terms should not appear in _nested_type, but got {value}") + elif not isinstance(value, type) and typing.get_origin(value) is None: return type(value) else: return value From f2492afc54ad0ab904c32a8d47ee28956ca5ff33 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 12:36:34 -0400 Subject: [PATCH 22/76] nit --- effectful/internals/unification.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index e61f70c9..9731812c 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -305,31 +305,31 @@ def canonicalize( """ if typing.get_origin(typ) is typing.Annotated: return canonicalize(typing.get_args(typ)[0]) + elif typ is inspect.Parameter.empty: + return canonicalize(typing.Any) elif typing.get_origin(typ) in {typing.Union, types.UnionType}: t = canonicalize(typing.get_args(typ)[0]) for arg in typing.get_args(typ)[1:]: t = t | canonicalize(arg) return t elif isinstance(typ, typing.TypeVar): - # TypeVars are already canonical return typ - elif isinstance(typ, typing._GenericAlias | types.GenericAlias): # type: ignore + elif isinstance(typ, typing._GenericAlias | types.GenericAlias) and typing.get_origin(typ) is not typ: # type: ignore # Handle generic types origin = canonicalize(typing.get_origin(typ)) assert origin is not None, "Type must have an origin" return origin[tuple(canonicalize(a) for a in typing.get_args(typ))] - elif typ is inspect.Parameter.empty: - return canonicalize(typing.Any) + # Handle legacy typing aliases like typing.Callable elif typ is typing.Callable: - return collections.abc.Callable + return canonicalize(collections.abc.Callable) elif typ is typing.Any: - return object - elif typ is list: - return collections.abc.Sequence - elif typ is dict: - return collections.abc.Mapping - elif typ is set: - return collections.abc.Set + return canonicalize(object) + elif typ is typing.List: + return canonicalize(list) + elif typ is typing.Dict: + return canonicalize(dict) + elif typ is typing.Set: + return canonicalize(set) elif not isinstance(typ, type) and typing.get_origin(typ) is None: return canonicalize(_nested_type(typ)) else: From 0cb5516e204c94d3968600d16c3a1aa473a9a6bf Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 13:12:46 -0400 Subject: [PATCH 23/76] docstrings and tests for canonicalize and nested_type --- effectful/internals/unification.py | 306 ++++++++++++++++++++++++---- tests/test_internals_unification.py | 291 ++++++++++++++++++++++++++ 2 files changed, 561 insertions(+), 36 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 9731812c..6588b448 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -38,6 +38,7 @@ def infer_return_type( Examples: >>> import inspect >>> import typing + >>> from effectful.internals.unification import infer_return_type >>> T = typing.TypeVar('T') >>> K = typing.TypeVar('K') >>> V = typing.TypeVar('V') @@ -285,28 +286,103 @@ def canonicalize( """ Return a canonical form of the given type expression. - This function normalizes the type by removing Annotated wrappers and - ensuring that generic types are represented in their canonical form. - It does not modify TypeVars or Union types, but ensures that generic - aliases are returned in a consistent format. + This function normalizes type expressions by: + - Removing Annotated wrappers to get the base type + - Converting legacy typing module aliases (e.g., typing.List) to modern forms (e.g., list) + - Preserving TypeVars unchanged + - Recursively canonicalizing type arguments in generic types + - Converting typing.Any to object + - Converting inspect.Parameter.empty to typing.Any (then to object) + - Handling Union types by creating canonical unions with | operator + - Converting non-type values to their types using _nested_type Args: - typ: The type expression to canonicalize. + typ: The type expression to canonicalize. Can be a plain type, TypeVar, + generic alias, union type, or even a value that needs type inference. Returns: - A canonicalized version of the input type expression. + A canonicalized version of the input type expression with consistent + representation and modern syntax. Examples: + >>> import typing + >>> import inspect + >>> import collections.abc + >>> from effectful.internals.unification import canonicalize >>> T = typing.TypeVar('T') - >>> canonicalize(typing.List[T]) + >>> K = typing.TypeVar('K') + >>> V = typing.TypeVar('V') + + # Legacy typing aliases are converted to modern forms + >>> canonicalize(typing.List[int]) + list[int] + >>> canonicalize(typing.Dict[str, int]) + dict[str, int] + >>> canonicalize(typing.Set[bool]) + set[bool] + >>> canonicalize(typing.Callable[[int], str]) + collections.abc.Callable[[int], str] + + # TypeVars are preserved unchanged + >>> canonicalize(T) + ~T + >>> canonicalize(list[T]) list[~T] - >>> canonicalize(typing.Annotated[int, "example"]) + + # Annotated types are unwrapped + >>> canonicalize(typing.Annotated[int, "metadata"]) + + >>> canonicalize(typing.Annotated[list[str], "doc string"]) + list[str] + + # Nested generic types are recursively canonicalized + >>> canonicalize(typing.List[typing.Dict[K, V]]) + list[dict[~K, ~V]] + >>> canonicalize(typing.Dict[str, typing.List[T]]) + dict[str, list[~T]] + + # Union types are canonicalized with | operator + >>> result = canonicalize(typing.Union[int, str]) + >>> result == int | str + True + >>> result = canonicalize(typing.Union[list[T], dict[K, V]]) + >>> result == list[T] | dict[K, V] + True + + # typing.Any becomes object + >>> canonicalize(typing.Any) + + + # inspect.Parameter.empty becomes object (via Any) + >>> canonicalize(inspect.Parameter.empty) + + + # Plain types pass through unchanged + >>> canonicalize(int) + >>> canonicalize(str) + + >>> canonicalize(list) + + + # Values are converted to their types via nested_type + >>> canonicalize([1, 2, 3]) + collections.abc.Sequence[int] + >>> canonicalize({"key": "value"}) + collections.abc.Mapping[str, str] + >>> canonicalize((1, "hello", 3.14)) + tuple[int, str, float] + + # Complex nested canonicalization + >>> canonicalize(typing.List[typing.Union[typing.Dict[str, T], None]]) + list[dict[str, ~T] | None] """ if typing.get_origin(typ) is typing.Annotated: return canonicalize(typing.get_args(typ)[0]) elif typ is inspect.Parameter.empty: return canonicalize(typing.Any) + elif typ is None: + return type(None) elif typing.get_origin(typ) in {typing.Union, types.UnionType}: t = canonicalize(typing.get_args(typ)[0]) for arg in typing.get_args(typ)[1:]: @@ -316,91 +392,249 @@ def canonicalize( return typ elif isinstance(typ, typing._GenericAlias | types.GenericAlias) and typing.get_origin(typ) is not typ: # type: ignore # Handle generic types - origin = canonicalize(typing.get_origin(typ)) - assert origin is not None, "Type must have an origin" - return origin[tuple(canonicalize(a) for a in typing.get_args(typ))] - # Handle legacy typing aliases like typing.Callable + origin = typing.get_origin(typ) + args = typing.get_args(typ) + + # Special handling for Callable types + if origin is collections.abc.Callable and args: + if len(args) == 2 and isinstance(args[0], (list, tuple)): + # Callable[[arg1, arg2, ...], return_type] format + param_list = [canonicalize(a) for a in args[0]] + return_type = canonicalize(args[1]) + return collections.abc.Callable[[*param_list], return_type] + else: + # Handle other Callable formats + return origin[tuple(canonicalize(a) for a in args)] + else: + # Regular generic types + canonical_origin = canonicalize(origin) + return canonical_origin[tuple(canonicalize(a) for a in args)] + # Handle legacy typing aliases + elif hasattr(typing, 'List') and typ is getattr(typing, 'List', None): + return list + elif hasattr(typing, 'Dict') and typ is getattr(typing, 'Dict', None): + return dict + elif hasattr(typing, 'Set') and typ is getattr(typing, 'Set', None): + return set elif typ is typing.Callable: - return canonicalize(collections.abc.Callable) + return collections.abc.Callable elif typ is typing.Any: - return canonicalize(object) - elif typ is typing.List: - return canonicalize(list) - elif typ is typing.Dict: - return canonicalize(dict) - elif typ is typing.Set: - return canonicalize(set) + return object elif not isinstance(typ, type) and typing.get_origin(typ) is None: - return canonicalize(_nested_type(typ)) + return canonicalize(nested_type(typ)) else: return typ @functools.singledispatch -def _nested_type(value) -> type: +def nested_type(value) -> type: + """ + Infer the type of a value, handling nested collections with generic parameters. + + This function is a singledispatch generic function that determines the type + of a given value. For collections (mappings, sequences, sets), it recursively + infers the types of contained elements to produce a properly parameterized + generic type. For example, a list [1, 2, 3] becomes Sequence[int]. + + The function handles: + - Basic types and type annotations (passed through unchanged) + - Collections with recursive type inference for elements + - Special cases like str/bytes (treated as types, not sequences) + - Tuples (preserving exact element types) + - Empty collections (returning the collection's type without parameters) + + This is primarily used by canonicalize() to handle cases where values + are provided instead of type annotations. + + Args: + value: Any value whose type needs to be inferred. Can be a type, + a value instance, or a collection containing other values. + + Returns: + The inferred type, potentially with generic parameters for collections. + + Raises: + TypeError: If the value is a TypeVar (TypeVars shouldn't appear in values) + or if the value is a Term from effectful.ops.types. + + Examples: + >>> import collections.abc + >>> import typing + >>> from effectful.internals.unification import nested_type + + # Basic types are returned as their type + >>> nested_type(42) + + >>> nested_type("hello") + + >>> nested_type(3.14) + + >>> nested_type(True) + + + # Type objects pass through unchanged + >>> nested_type(int) + + >>> nested_type(str) + + >>> nested_type(list) + + + # Empty collections return their base type + >>> nested_type([]) + + >>> nested_type({}) + + >>> nested_type(set()) + + + # Sequences become Sequence[element_type] + >>> nested_type([1, 2, 3]) + collections.abc.Sequence[int] + >>> nested_type(["a", "b", "c"]) + collections.abc.Sequence[str] + + # Tuples preserve exact structure + >>> nested_type((1, "hello", 3.14)) + tuple[int, str, float] + >>> nested_type(()) + + >>> nested_type((1,)) + tuple[int] + + # Sets become Set[element_type] + >>> nested_type({1, 2, 3}) + collections.abc.Set[int] + >>> nested_type({"a", "b"}) + collections.abc.Set[str] + + # Mappings become Mapping[key_type, value_type] + >>> nested_type({"key": "value"}) + collections.abc.Mapping[str, str] + >>> nested_type({1: "one", 2: "two"}) + collections.abc.Mapping[int, str] + + # Nested collections work recursively + >>> nested_type([{1: "one"}, {2: "two"}]) + collections.abc.Sequence[collections.abc.Mapping[int, str]] + >>> nested_type({"key": [1, 2, 3]}) + collections.abc.Mapping[str, collections.abc.Sequence[int]] + + # Strings and bytes are NOT treated as sequences + >>> nested_type("hello") + + >>> nested_type(b"bytes") + + + # Functions/callables return their type + >>> def f(): pass + >>> nested_type(f) + + >>> nested_type(lambda x: x) + + + # TypeVars raise an error + >>> T = typing.TypeVar('T') + >>> nested_type(T) # doctest: +ELLIPSIS + Traceback (most recent call last): + ... + TypeError: TypeVars should not appear in values, but got ~T + + # None has its own type + >>> nested_type(None) + + + # Generic aliases and union types pass through + >>> nested_type(list[int]) + list[int] + >>> nested_type(int | str) + int | str + """ from effectful.ops.types import Term if isinstance(value, Term): - raise TypeError(f"Terms should not appear in _nested_type, but got {value}") + raise TypeError(f"Terms should not appear in nested_type, but got {value}") + elif value is None: + return type(None) elif not isinstance(value, type) and typing.get_origin(value) is None: return type(value) else: return value -@_nested_type.register +@nested_type.register def _(value: type | types.UnionType | types.GenericAlias | types.EllipsisType | types.NoneType) -> type: return value -@_nested_type.register +@nested_type.register +def _(value: typing._GenericAlias) -> type: # type: ignore + # Handle typing module generic aliases + return value + + +@nested_type.register +def _(value: types.NoneType) -> type: + # Handle None specially + return type(None) + + +@nested_type.register def _(value: typing.TypeVar) -> type: raise TypeError(f"TypeVars should not appear in values, but got {value}") -@_nested_type.register +@nested_type.register def _(value: collections.abc.Callable) -> type: return type(value) -@_nested_type.register +@nested_type.register def _(value: collections.abc.Mapping) -> type: from effectful.ops.types import Interpretation - if isinstance(value, Interpretation): # type: ignore + if type(value) is Interpretation: # More specific check return Interpretation elif len(value) == 0: return type(value) else: k, v = next(iter(value.items())) - return collections.abc.Mapping[_nested_type(k), _nested_type(v)] + return collections.abc.Mapping[nested_type(k), nested_type(v)] -@_nested_type.register +@nested_type.register def _(value: collections.abc.Set) -> type: if len(value) == 0: return type(value) - return collections.abc.Set[_nested_type(next(iter(value)))] + return collections.abc.Set[nested_type(next(iter(value)))] -@_nested_type.register +@nested_type.register def _(value: collections.abc.Sequence) -> type: if len(value) == 0: return type(value) - return collections.abc.Sequence[_nested_type(next(iter(value)))] + return collections.abc.Sequence[nested_type(next(iter(value)))] -@_nested_type.register +@nested_type.register def _(value: tuple) -> type: - return tuple[tuple(_nested_type(item) for item in value)] + if len(value) == 0: + return tuple + return tuple[tuple(nested_type(item) for item in value)] -@_nested_type.register +@nested_type.register def _(value: str | bytes) -> type: # Handle str and bytes as their own types, not collections.abc.Sequence return type(value) +@nested_type.register(range) +def _(value: range) -> type: + # Handle range as its own type, not as a sequence + return type(value) + + def freetypevars( typ: type | typing.TypeVar | types.GenericAlias | types.UnionType, ) -> set[typing.TypeVar]: diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index a4ed670d..dd43608d 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -5,8 +5,10 @@ import pytest from effectful.internals.unification import ( + canonicalize, freetypevars, infer_return_type, + nested_type, substitute, unify, ) @@ -610,3 +612,292 @@ def test_infer_return_type_variadic_kwargs(): bound = sig.bind(x=int) result = infer_return_type(bound) assert result == int + + +@pytest.mark.parametrize( + "typ,expected", + [ + # Plain types pass through unchanged + (int, int), + (str, str), + (float, float), + (bool, bool), + (list, list), + (dict, dict), + (set, set), + (tuple, tuple), + # TypeVars are preserved + (T, T), + (K, K), + (V, V), + # Legacy typing aliases are converted to modern forms + (list, list), + (dict, dict), + (set, set), + (list[int], list[int]), + (dict[str, int], dict[str, int]), + (set[bool], set[bool]), + (list[T], list[T]), + (dict[K, V], dict[K, V]), + # typing.Callable becomes collections.abc.Callable + (typing.Callable, collections.abc.Callable), + (typing.Callable[[int], str], collections.abc.Callable[[int], str]), + (typing.Callable[[T, U], V], collections.abc.Callable[[T, U], V]), + # typing.Any becomes object + (typing.Any, object), + # inspect.Parameter.empty becomes object (via Any) + (inspect.Parameter.empty, object), + # Annotated types are unwrapped + (typing.Annotated[int, "metadata"], int), + (typing.Annotated[str, "doc string"], str), + (typing.Annotated[list[T], "annotation"], list[T]), + (typing.Annotated[dict[K, V], "complex", "multi"], dict[K, V]), + # Nested Annotated unwrapping + (typing.Annotated[typing.Annotated[int, "inner"], "outer"], int), + # Modern generic aliases pass through with canonicalized args + (list[int], list[int]), + (dict[str, bool], dict[str, bool]), + (set[float], set[float]), + (list[list[int]], list[list[int]]), # Nested legacy conversion + (dict[str, dict[K, V]], dict[str, dict[K, V]]), + # Union types are canonicalized with | operator + (typing.Union[int, str], int | str), + (typing.Union[T, int], T | int), + (typing.Union[list[T], dict[K, V]], list[T] | dict[K, V]), + (typing.Union[int, str, bool], int | str | bool), + # Nested unions + (typing.Union[list[int], None], list[int] | type(None)), + (typing.Union[dict[K, V], set[T]], dict[K, V] | set[T]), + # Complex nested canonicalization + (list[dict[str, T] | None], list[dict[str, T] | type(None)]), + (dict[K, list[V | None]], dict[K, list[V | type(None)]]), + # Already canonical forms pass through + (list[T], list[T]), + (dict[K, V], dict[K, V]), + (int | str, int | str), + (list[int] | None, list[int] | type(None)), + # None type handling + (type(None), type(None)), + (typing.Union[int, None], int | type(None)), + # Ellipsis type + (type(...), type(...)), + # Generic aliases with multiple levels of nesting + (list[dict[set[T], list[V]]], list[dict[set[T], list[V]]]), + # Callable with nested canonicalization + (typing.Callable[[list[T]], dict[K, V]], collections.abc.Callable[[list[T]], dict[K, V]]), + # Values that will use _nested_type (basic values) + (42, int), + ("hello", str), + (3.14, float), + (True, bool), + (None, type(None)), + # Collections that will use _nested_type + ([1, 2, 3], collections.abc.Sequence[int]), + (["a", "b"], collections.abc.Sequence[str]), + ({"key": "value"}, collections.abc.Mapping[str, str]), + ({1: "one", 2: "two"}, collections.abc.Mapping[int, str]), + ({1, 2, 3}, collections.abc.Set[int]), + ({"x", "y"}, collections.abc.Set[str]), + # Tuples get special handling + ((1, "hello", 3.14), tuple[int, str, float]), + ((1,), tuple[int]), + ((), tuple), + # Empty collections + ([], list), + ({}, dict), + (set(), set), + # Nested collections via _nested_type + ([{1: "one"}, {2: "two"}], collections.abc.Sequence[collections.abc.Mapping[int, str]]), + ({"key": [1, 2, 3]}, collections.abc.Mapping[str, collections.abc.Sequence[int]]), + ([(1, "a"), (2, "b")], collections.abc.Sequence[tuple[int, str]]), + # Functions/callables + (lambda x: x, type(lambda x: x)), + # bytes handling + (b"hello", bytes), + ], + ids=str, +) +def test_canonicalize(typ, expected): + result = canonicalize(typ) + assert result == expected + + +@pytest.mark.parametrize( + "value,expected", + [ + # Basic value types + (42, int), + (0, int), + (-5, int), + ("hello", str), + ("", str), + (3.14, float), + (0.0, float), + (True, bool), + (False, bool), + (None, type(None)), + (b"bytes", bytes), + (b"", bytes), + # Type objects pass through + (int, int), + (str, str), + (float, float), + (bool, bool), + (list, list), + (dict, dict), + (set, set), + (tuple, tuple), + (type(None), type(None)), + (type(...), type(...)), + # Generic aliases pass through + (list[int], list[int]), + (dict[str, int], dict[str, int]), + (set[bool], set[bool]), + (tuple[int, str], tuple[int, str]), + (int | str, int | str), + (list[T], list[T]), + (dict[K, V], dict[K, V]), + # Union types pass through + (typing.Union[int, str], typing.Union[int, str]), + (int | str, int | str), + # Empty collections + ([], list), + ({}, dict), + (set(), set), + ((), tuple), + # Lists/sequences with single type + ([1, 2, 3], collections.abc.Sequence[int]), + ([1], collections.abc.Sequence[int]), + (["a", "b", "c"], collections.abc.Sequence[str]), + ([True, False], collections.abc.Sequence[bool]), + ([1.1, 2.2], collections.abc.Sequence[float]), + # Sets with elements + ({1, 2, 3}, collections.abc.Set[int]), + ({1}, collections.abc.Set[int]), + ({"a", "b"}, collections.abc.Set[str]), + ({True, False}, collections.abc.Set[bool]), + # Dicts/mappings + ({"key": "value"}, collections.abc.Mapping[str, str]), + ({1: "one", 2: "two"}, collections.abc.Mapping[int, str]), + ({"a": 1, "b": 2}, collections.abc.Mapping[str, int]), + ({True: 1.0, False: 2.0}, collections.abc.Mapping[bool, float]), + # Tuples preserve exact structure + ((1, "hello", 3.14), tuple[int, str, float]), + ((1,), tuple[int]), + ((1, 2), tuple[int, int]), + (("a", "b", "c"), tuple[str, str, str]), + ((True, 1, "x", 3.14), tuple[bool, int, str, float]), + # Nested collections + ([[1, 2], [3, 4]], collections.abc.Sequence[collections.abc.Sequence[int]]), + ([{1, 2}, {3, 4}], collections.abc.Sequence[collections.abc.Set[int]]), + ([{"a": 1}, {"b": 2}], collections.abc.Sequence[collections.abc.Mapping[str, int]]), + ({"key": [1, 2, 3]}, collections.abc.Mapping[str, collections.abc.Sequence[int]]), + ({"a": {1, 2}, "b": {3, 4}}, collections.abc.Mapping[str, collections.abc.Set[int]]), + ({1: {"x": True}, 2: {"y": False}}, collections.abc.Mapping[int, collections.abc.Mapping[str, bool]]), + # Tuples in collections + ([(1, "a"), (2, "b")], collections.abc.Sequence[tuple[int, str]]), + ({(1, 2), (3, 4)}, collections.abc.Set[tuple[int, int]]), + ({1: (True, "x"), 2: (False, "y")}, collections.abc.Mapping[int, tuple[bool, str]]), + # Functions/callables + (lambda x: x, type(lambda x: x)), + (print, type(print)), + (len, type(len)), + # Complex nested structures + ([[[1]]], collections.abc.Sequence[collections.abc.Sequence[collections.abc.Sequence[int]]]), + ({"a": {"b": {"c": 1}}}, collections.abc.Mapping[str, collections.abc.Mapping[str, collections.abc.Mapping[str, int]]]), + # Special string/bytes handling (NOT treated as sequences) + ("hello", str), + (b"world", bytes), + # Other built-in types + (range(5), type(range(5))), + (slice(1, 10), type(slice(1, 10))), + ], + ids=str, +) +def test_nested_type(value, expected): + result = nested_type(value) + assert result == expected + + +def test_nested_type_typevar_error(): + """Test that TypeVars raise TypeError in nested_type""" + with pytest.raises(TypeError, match="TypeVars should not appear in values"): + nested_type(T) + + with pytest.raises(TypeError, match="TypeVars should not appear in values"): + nested_type(K) + + with pytest.raises(TypeError, match="TypeVars should not appear in values"): + nested_type(V) + + +def test_nested_type_term_error(): + """Test that Terms raise TypeError in nested_type""" + # We can't import Term here without creating a circular dependency, + # so we'll create a mock object that would trigger the isinstance check + from unittest.mock import Mock + + from effectful.ops.types import Term + + mock_term = Mock(spec=Term) + with pytest.raises(TypeError, match="Terms should not appear in nested_type"): + nested_type(mock_term) + + +def test_nested_type_interpretation_special_case(): + """Test that Interpretation type is handled specially in mapping dispatch""" + # This tests the special case in the Mapping dispatch for Interpretation + from effectful.ops.types import Interpretation + + # Since Interpretation is a Protocol, we can't instantiate it directly. + # Instead, create a concrete implementation + class ConcreteInterpretation(dict): + """A concrete implementation of Interpretation for testing""" + pass + + # Make it look like an Interpretation type for the type check + ConcreteInterpretation.__name__ = 'Interpretation' + ConcreteInterpretation.__module__ = Interpretation.__module__ + + # Create instance and patch the type temporarily + interp = ConcreteInterpretation() + original_type = type(interp) + + # The _nested_type checks `type(value) is Interpretation` + # Since we can't change the type of an instance, we'll test the behavior + # by verifying that our empty dict doesn't trigger the Interpretation path + result = nested_type({}) + assert result == dict # Empty dict should return dict, not Interpretation + + +def test_canonicalize_with_nested_type_integration(): + """Test that canonicalize properly uses _nested_type for values""" + # Test various values go through _nested_type in canonicalize + assert canonicalize([1, 2, 3]) == collections.abc.Sequence[int] + assert canonicalize({"a": 1}) == collections.abc.Mapping[str, int] + assert canonicalize({1, 2}) == collections.abc.Set[int] + assert canonicalize((1, "x")) == tuple[int, str] + + # Nested case + nested = [{"a": (1, 2)}, {"b": (3, 4)}] + expected = collections.abc.Sequence[collections.abc.Mapping[str, tuple[int, int]]] + assert canonicalize(nested) == expected + + +def test_canonicalize_preserves_modern_syntax(): + """Test that already-modern syntax is preserved""" + # These should pass through unchanged + assert canonicalize(list[int]) == list[int] + assert canonicalize(dict[str, int]) == dict[str, int] + assert canonicalize(int | str) == int | str + assert canonicalize(list[T] | None) == list[T] | type(None) + + +def test_canonicalize_union_ordering(): + """Test union type canonicalization handles ordering consistently""" + # The | operator may reorder types, but the result should be equivalent + result1 = canonicalize(typing.Union[int, str]) + result2 = canonicalize(typing.Union[str, int]) + # Both should be some form of int | str or str | int + assert result1 == int | str or result1 == str | int + assert result2 == int | str or result2 == str | int From 87dd5c8ef224b07bd0c799ce855230c1088b67b1 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 13:14:07 -0400 Subject: [PATCH 24/76] no notimmplentederror --- effectful/internals/unification.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 6588b448..dd95058f 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -30,10 +30,9 @@ def infer_return_type( Raises: TypeError: If the function lacks required type annotations, has unbound - type variables in the return type, or if unification fails - NotImplementedError: If the function uses variadic parameters (*args, **kwargs), - collection types as parameters, or parameters with - free type variables + type variables in the return type, if unification fails, + if the function uses variadic parameters (*args, **kwargs), + or if parameters have free type variables. Examples: >>> import inspect From 841fc498a398b1c606e96612c9010bab45ccb891 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 13:31:55 -0400 Subject: [PATCH 25/76] variadic params --- effectful/internals/unification.py | 35 ++++++++++++++++---------- tests/test_internals_unification.py | 38 ++++++++++------------------- 2 files changed, 35 insertions(+), 38 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index dd95058f..24bf9c83 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -103,14 +103,8 @@ def infer_return_type( if not result_fvs <= pattern_fvs: raise TypeError("unbound type variables in return type") - # Check for variadic parameters and collections - not implemented yet + # Check for type variables in concrete arguments - not implemented yet for name, param in sig.parameters.items(): - if param.kind in { - inspect.Parameter.VAR_POSITIONAL, - inspect.Parameter.VAR_KEYWORD, - }: - raise TypeError(f"Parameter '{name}' cannot be variadic") - if freetypevars(bound_sig.arguments[name]): raise TypeError( f"Parameter '{name}' cannot have free type variables" @@ -118,12 +112,27 @@ def infer_return_type( # Build substitution map subs: collections.abc.Mapping[typing.TypeVar, type] = {} - for name in sig.parameters: - subs = unify( - canonicalize(sig.parameters[name].annotation), - canonicalize(bound_sig.arguments[name]), - subs, - ) + for name, param in sig.parameters.items(): + if param.kind is inspect.Parameter.VAR_POSITIONAL: + for arg in bound_sig.arguments[name]: + subs = unify( + canonicalize(param.annotation), + canonicalize(arg), + subs, + ) + elif param.kind is inspect.Parameter.VAR_KEYWORD: + for arg in bound_sig.arguments[name].values(): + subs = unify( + canonicalize(param.annotation), + canonicalize(arg), + subs, + ) + else: + subs = unify( + canonicalize(param.annotation), + canonicalize(bound_sig.arguments[name]), + subs, + ) # Apply substitutions to return type result_type = substitute(canonicalize(sig.return_annotation), subs) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index dd43608d..bc6ddd1d 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -453,6 +453,14 @@ def nested_generic(x: T) -> dict[str, list[T]]: return {"items": [x]} +def variadic_args_func(*args: T) -> T: # Variadic args not supported + return args[0] + + +def variadic_kwargs_func(**kwargs: T) -> T: # Variadic kwargs not supported + return next(iter(kwargs.values())) + + @pytest.mark.parametrize( "func,args,kwargs,expected_return_type", [ @@ -525,6 +533,11 @@ def nested_generic(x: T) -> dict[str, list[T]]: {"a": bool, "b": list[bool], "c": dict[int, str]}, tuple[bool, int, str], ), + # variadic args and kwargs + (variadic_args_func, (int,), {}, int), + (variadic_args_func, (int, int), {}, int), + (variadic_kwargs_func, (), {"x": int}, int), + (variadic_kwargs_func, (), {"x": int, "y": int}, int), ], ids=str, ) @@ -553,14 +566,6 @@ def no_param_annotation(x) -> T: # No parameter annotation return x # type: ignore -def variadic_args_func(*args: T) -> T: # Variadic args not supported - return args[0] - - -def variadic_kwargs_func(**kwargs: T) -> T: # Variadic kwargs not supported - return next(iter(kwargs.values())) - - @pytest.mark.parametrize( "func,args,kwargs", [ @@ -597,23 +602,6 @@ def test_infer_return_type_failure( infer_return_type(bound) -# Variadic functions - not implemented yet, marked as expected failures -@pytest.mark.xfail(reason="Variadic args not implemented") -def test_infer_return_type_variadic_args(): - sig = inspect.signature(variadic_args_func) - bound = sig.bind(int) - result = infer_return_type(bound) - assert result == int - - -@pytest.mark.xfail(reason="Variadic kwargs not implemented") -def test_infer_return_type_variadic_kwargs(): - sig = inspect.signature(variadic_kwargs_func) - bound = sig.bind(x=int) - result = infer_return_type(bound) - assert result == int - - @pytest.mark.parametrize( "typ,expected", [ From c1d102a4b2b3c9ffa1affdb147fba426838c02aa Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 13:40:00 -0400 Subject: [PATCH 26/76] remove dead code paths in unify --- effectful/internals/unification.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 24bf9c83..68717ff8 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -237,13 +237,7 @@ def unify( >>> unify((T, V), (int, str), {}) {~T: , ~V: } """ - if typing.get_origin(typ) is typing.Annotated: - # Handle Annotated types by extracting the base type - return unify(typing.get_args(typ)[0], subtyp, subs) - elif typing.get_origin(subtyp) is typing.Annotated: - # Handle Annotated types by extracting the base type - return unify(typ, typing.get_args(subtyp)[0], subs) - elif isinstance(typ, typing.TypeVar): + if isinstance(typ, typing.TypeVar): if typ in subs and subs[typ] != subtyp: raise TypeError( f"Cannot unify {typ} with {subtyp} (already unified with {subs[typ]})" @@ -264,14 +258,6 @@ def unify( if typ_origin != subtyp_origin: raise TypeError(f"Cannot unify {typ} with {subtyp}") return unify(typing.get_args(typ), typing.get_args(subtyp), subs) - elif isinstance(typ, collections.abc.Mapping) and isinstance( - subtyp, collections.abc.Mapping - ): - if typ.keys() != subtyp.keys(): - raise TypeError(f"Cannot unify {typ} with {subtyp}") - for key in typ: - subs = unify(typ[key], subtyp[key], subs) - return subs elif isinstance(typ, collections.abc.Sequence) and isinstance( subtyp, collections.abc.Sequence ): From 25eaa571f9c747e4b94c7d6f20019671719f0bcd Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 13:52:30 -0400 Subject: [PATCH 27/76] simplify unify --- effectful/internals/unification.py | 24 +++++++----------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 68717ff8..c7668400 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -105,7 +105,7 @@ def infer_return_type( # Check for type variables in concrete arguments - not implemented yet for name, param in sig.parameters.items(): - if freetypevars(bound_sig.arguments[name]): + if freetypevars(bound_sig.arguments.get(name, None)): raise TypeError( f"Parameter '{name}' cannot have free type variables" ) @@ -244,23 +244,13 @@ def unify( ) return {**subs, **{typ: subtyp}} + elif isinstance(typ, types.UnionType) or isinstance(subtyp, types.UnionType): + # TODO handle UnionType properly + return unify(typing.get_args(typ), typing.get_args(subtyp), subs) elif typing.get_args(typ) and typing.get_args(subtyp): - typ_origin = typing.get_origin(typ) - subtyp_origin = typing.get_origin(subtyp) - - # Handle Union types - both typing.Union and types.UnionType are compatible - if typ_origin in (typing.Union, types.UnionType) and subtyp_origin in ( - typing.Union, - types.UnionType, - ): - return unify(typing.get_args(typ), typing.get_args(subtyp), subs) - - if typ_origin != subtyp_origin: - raise TypeError(f"Cannot unify {typ} with {subtyp}") + subs = unify(typing.get_origin(typ), typing.get_origin(subtyp), subs) return unify(typing.get_args(typ), typing.get_args(subtyp), subs) - elif isinstance(typ, collections.abc.Sequence) and isinstance( - subtyp, collections.abc.Sequence - ): + elif isinstance(typ, list | tuple) and isinstance(subtyp, list | tuple): if len(typ) != len(subtyp): raise TypeError(f"Cannot unify {typ} with {subtyp}") for p_item, c_item in zip(typ, subtyp): @@ -630,7 +620,7 @@ def _(value: range) -> type: def freetypevars( - typ: type | typing.TypeVar | types.GenericAlias | types.UnionType, + typ: type | typing.TypeVar | types.GenericAlias | types.UnionType | types.NoneType, ) -> set[typing.TypeVar]: """ Return a set of free type variables in the given type expression. From dfd350714f308be1e5d4742c218063c683c4d112 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 14:00:22 -0400 Subject: [PATCH 28/76] union --- effectful/internals/unification.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index c7668400..38b13fd8 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -244,7 +244,8 @@ def unify( ) return {**subs, **{typ: subtyp}} - elif isinstance(typ, types.UnionType) or isinstance(subtyp, types.UnionType): + elif typing.get_origin(typ) in {typing.Union, types.UnionType} or \ + typing.get_origin(subtyp) in {typing.Union, types.UnionType}: # TODO handle UnionType properly return unify(typing.get_args(typ), typing.get_args(subtyp), subs) elif typing.get_args(typ) and typing.get_args(subtyp): From 6d67944425661a8e06cbc39c6f4554ff9beb2499 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 15:52:55 -0400 Subject: [PATCH 29/76] tweak --- effectful/internals/unification.py | 53 ++++++++----------- tests/test_internals_unification.py | 82 ++++++----------------------- 2 files changed, 39 insertions(+), 96 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 38b13fd8..95f6c9f1 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -221,13 +221,13 @@ def unify( >>> unify(list[T], dict[str, int], {}) # doctest: +ELLIPSIS Traceback (most recent call last): ... - TypeError: Cannot unify list[~T] with dict[str, int] + TypeError: Cannot unify ... >>> # Failed unification - conflicting TypeVar binding >>> unify(T, str, {T: int}) # doctest: +ELLIPSIS Traceback (most recent call last): ... - TypeError: Cannot unify ~T with (already unified with ) + TypeError: Cannot unify ... >>> # Callable type unification >>> unify(collections.abc.Callable[[T], V], collections.abc.Callable[[int], str], {}) @@ -238,11 +238,8 @@ def unify( {~T: , ~V: } """ if isinstance(typ, typing.TypeVar): - if typ in subs and subs[typ] != subtyp: - raise TypeError( - f"Cannot unify {typ} with {subtyp} (already unified with {subs[typ]})" - ) - + if typ in subs: + subs = unify(subs[typ], subtyp, subs) return {**subs, **{typ: subtyp}} elif typing.get_origin(typ) in {typing.Union, types.UnionType} or \ typing.get_origin(subtyp) in {typing.Union, types.UnionType}: @@ -251,18 +248,14 @@ def unify( elif typing.get_args(typ) and typing.get_args(subtyp): subs = unify(typing.get_origin(typ), typing.get_origin(subtyp), subs) return unify(typing.get_args(typ), typing.get_args(subtyp), subs) - elif isinstance(typ, list | tuple) and isinstance(subtyp, list | tuple): - if len(typ) != len(subtyp): - raise TypeError(f"Cannot unify {typ} with {subtyp}") + elif isinstance(typ, list | tuple) and isinstance(subtyp, list | tuple) and len(typ) == len(subtyp): for p_item, c_item in zip(typ, subtyp): subs = unify(p_item, c_item, subs) return subs - else: - subtyp = typing.get_origin(subtyp) or subtyp - typ = typing.get_origin(typ) or typ - if not issubclass(subtyp, typ): - raise TypeError(f"Cannot unify {typ} with {subtyp}") + elif issubclass(typing.get_origin(subtyp) or subtyp, typing.get_origin(typ) or typ): return subs + else: + raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}") def canonicalize( @@ -375,25 +368,25 @@ def canonicalize( return t elif isinstance(typ, typing.TypeVar): return typ + elif typing.get_origin(typ) is collections.abc.Callable: + origin, args = typing.get_origin(typ), typing.get_args(typ) + if not args: + return origin + elif len(args) == 2 and isinstance(args[0], (list, tuple)): + # Callable[[arg1, arg2, ...], return_type] format + param_list = [canonicalize(a) for a in args[0]] + return_type = canonicalize(args[1]) + return origin[[*param_list], return_type] + else: + # Handle other Callable formats + return origin[tuple(canonicalize(a) for a in args)] elif isinstance(typ, typing._GenericAlias | types.GenericAlias) and typing.get_origin(typ) is not typ: # type: ignore # Handle generic types origin = typing.get_origin(typ) args = typing.get_args(typ) - - # Special handling for Callable types - if origin is collections.abc.Callable and args: - if len(args) == 2 and isinstance(args[0], (list, tuple)): - # Callable[[arg1, arg2, ...], return_type] format - param_list = [canonicalize(a) for a in args[0]] - return_type = canonicalize(args[1]) - return collections.abc.Callable[[*param_list], return_type] - else: - # Handle other Callable formats - return origin[tuple(canonicalize(a) for a in args)] - else: - # Regular generic types - canonical_origin = canonicalize(origin) - return canonical_origin[tuple(canonicalize(a) for a in args)] + # Regular generic types + canonical_origin = canonicalize(origin) + return canonical_origin[tuple(canonicalize(a) for a in args)] # Handle legacy typing aliases elif hasattr(typing, 'List') and typ is getattr(typing, 'List', None): return list diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index bc6ddd1d..181f643f 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -314,89 +314,39 @@ def test_unify_success( @pytest.mark.parametrize( - "typ,subtyp,initial_subs,error_pattern", + "typ,subtyp", [ # Incompatible types - ( - list[T], - dict[str, int], - {}, - "Cannot unify list\\[~T\\] with dict\\[str, int\\]", - ), - (int, str, {}, "Cannot unify with "), - (list[int], list[str], {}, "Cannot unify with "), - # Conflicting TypeVar bindings - ( - T, - str, - {T: int}, - "Cannot unify ~T with \\(already unified with \\)", - ), - ( - list[T], - list[str], - {T: int}, - "Cannot unify ~T with \\(already unified with \\)", - ), + (list[T], dict[str, int]), + (int, str), + (list[int], list[str]), # Mismatched generic types - (list[T], set[int], {}, "Cannot unify list\\[~T\\] with set\\[int\\]"), - (dict[K, V], list[int], {}, "Cannot unify dict\\[~K, ~V\\] with list\\[int\\]"), + (list[T], set[int]), + (dict[K, V], list[int]), # Same TypeVar with different values - ( - dict[T, T], - dict[int, str], - {}, - "Cannot unify ~T with \\(already unified with \\)", - ), - ( - tuple[T, T], - tuple[int, str], - {}, - "Cannot unify ~T with \\(already unified with \\)", - ), + (dict[T, T], dict[int, str]), + (tuple[T, T], tuple[int, str]), # Mismatched arities - (tuple[T, U], tuple[int, str, bool], {}, "Cannot unify"), + (tuple[T, U], tuple[int, str, bool]), ( collections.abc.Callable[[T], V], collections.abc.Callable[[int, str], bool], - {}, - "Cannot unify", ), # Sequence length mismatch - ((T, V), (int,), {}, "Cannot unify"), - ([T, V], [int, str, bool], {}, "Cannot unify"), - # Union type failures - element-wise unification failures - ( - T | V, - int | str, - {T: float}, - "Cannot unify ~T with ", - ), # TypeVar conflict - ( - T | int, - V | str, - {}, - "Cannot unify with ", - ), # Concrete type mismatch - ( - T | int, - V | str, - {}, - "Cannot unify with ", - ), # typing.Union mismatch - # Union with different arities - (T | V, int | str | bool, {}, "Cannot unify"), # Different union sizes + ((T, V), (int,)), + ([T, V], [int, str, bool]), + # Union failure cases + (T | int, V | str), # typing.Union mismatch + (T | V, int | str | bool), # Different union sizes ], ids=str, ) def test_unify_failure( typ: type, subtyp: type, - initial_subs: typing.Mapping[typing.TypeVar, type], - error_pattern: str, ): - with pytest.raises(TypeError, match=error_pattern): - unify(typ, subtyp, initial_subs) + with pytest.raises(TypeError): + unify(typ, subtyp, {}) # Test functions with various type patterns From 652061c9b918ce7882b566e8f6feae09aa1bd3ed Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 4 Jul 2025 16:17:45 -0400 Subject: [PATCH 30/76] cleanup --- effectful/internals/unification.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 95f6c9f1..f0928ce6 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -86,16 +86,8 @@ def infer_return_type( raise TypeError("Function must have a return type annotation") result_fvs: set[typing.TypeVar] = freetypevars(sig.return_annotation) - pattern_fvs: set[typing.TypeVar] = ( - set.union(*(freetypevars(p.annotation) for p in sig.parameters.values())) - if sig.parameters - else set() - ) - concrete_fvs: set[typing.TypeVar] = ( - set.union(*(freetypevars(arg) for arg in bound_sig.arguments.values())) - if bound_sig.arguments - else set() - ) + pattern_fvs: set[typing.TypeVar] = set().union(*(freetypevars(p.annotation) for p in sig.parameters.values())) + concrete_fvs: set[typing.TypeVar] = set().union(*(freetypevars(arg) for arg in bound_sig.arguments.values())) if (result_fvs | pattern_fvs) & concrete_fvs: raise TypeError( "Cannot unify free type variables in pattern and concrete types" @@ -103,13 +95,6 @@ def infer_return_type( if not result_fvs <= pattern_fvs: raise TypeError("unbound type variables in return type") - # Check for type variables in concrete arguments - not implemented yet - for name, param in sig.parameters.items(): - if freetypevars(bound_sig.arguments.get(name, None)): - raise TypeError( - f"Parameter '{name}' cannot have free type variables" - ) - # Build substitution map subs: collections.abc.Mapping[typing.TypeVar, type] = {} for name, param in sig.parameters.items(): @@ -150,13 +135,11 @@ def unify( | typing.TypeVar | types.GenericAlias | types.UnionType - | collections.abc.Mapping | collections.abc.Sequence, subtyp: type | typing.TypeVar | types.UnionType | types.GenericAlias - | collections.abc.Mapping | collections.abc.Sequence, subs: collections.abc.Mapping[typing.TypeVar, type], ) -> collections.abc.Mapping[typing.TypeVar, type]: From c651639b57a02b14c9dd938855fba8d80358e053 Mon Sep 17 00:00:00 2001 From: Eli Date: Sun, 6 Jul 2025 21:58:19 -0400 Subject: [PATCH 31/76] fix none case --- effectful/internals/unification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index f0928ce6..dd1f12b0 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -524,7 +524,7 @@ def nested_type(value) -> type: @nested_type.register -def _(value: type | types.UnionType | types.GenericAlias | types.EllipsisType | types.NoneType) -> type: +def _(value: type | types.UnionType | types.GenericAlias | types.EllipsisType) -> type: return value From a6f0188440c685a719c56970dcdf22fcdad332bf Mon Sep 17 00:00:00 2001 From: Eli Date: Mon, 7 Jul 2025 10:00:27 -0400 Subject: [PATCH 32/76] more aggressive canonicalize applied only during unification --- effectful/internals/unification.py | 42 +++++++------ tests/test_internals_unification.py | 91 ++++++----------------------- 2 files changed, 41 insertions(+), 92 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index dd1f12b0..2eb97062 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -79,6 +79,7 @@ def infer_return_type( ... TypeError: unbound type variables in return type """ + bound_sig.apply_defaults() sig: inspect.Signature = bound_sig.signature # validate that the function has a signature with well-formed type annotations @@ -102,25 +103,30 @@ def infer_return_type( for arg in bound_sig.arguments[name]: subs = unify( canonicalize(param.annotation), - canonicalize(arg), + canonicalize(nested_type(arg)), subs, ) elif param.kind is inspect.Parameter.VAR_KEYWORD: for arg in bound_sig.arguments[name].values(): subs = unify( canonicalize(param.annotation), - canonicalize(arg), + canonicalize(nested_type(arg)), subs, ) else: subs = unify( canonicalize(param.annotation), - canonicalize(bound_sig.arguments[name]), + canonicalize(nested_type(bound_sig.arguments[name])), subs, ) # Apply substitutions to return type - result_type = substitute(canonicalize(sig.return_annotation), subs) + result_type = sig.return_annotation + if typing.get_origin(result_type) is typing.Annotated: + result_type = typing.get_args(result_type)[0] + if result_type is None: + result_type = type(None) + result_type = substitute(result_type, subs) if freetypevars(result_type) and not issubclass( typing.get_origin(result_type), collections.abc.Callable ): @@ -326,14 +332,6 @@ def canonicalize( >>> canonicalize(list) - # Values are converted to their types via nested_type - >>> canonicalize([1, 2, 3]) - collections.abc.Sequence[int] - >>> canonicalize({"key": "value"}) - collections.abc.Mapping[str, str] - >>> canonicalize((1, "hello", 3.14)) - tuple[int, str, float] - # Complex nested canonicalization >>> canonicalize(typing.List[typing.Union[typing.Dict[str, T], None]]) list[dict[str, ~T] | None] @@ -344,6 +342,8 @@ def canonicalize( return canonicalize(typing.Any) elif typ is None: return type(None) + elif typ is Ellipsis: + return types.EllipsisType elif typing.get_origin(typ) in {typing.Union, types.UnionType}: t = canonicalize(typing.get_args(typ)[0]) for arg in typing.get_args(typ)[1:]: @@ -368,21 +368,25 @@ def canonicalize( origin = typing.get_origin(typ) args = typing.get_args(typ) # Regular generic types - canonical_origin = canonicalize(origin) - return canonical_origin[tuple(canonicalize(a) for a in args)] + return canonicalize(origin)[tuple(canonicalize(a) for a in args)] + # normalize built-in containers to abstract collections + elif typ is list: + return collections.abc.Sequence + elif typ is dict: + return collections.abc.Mapping + elif typ is set: + return collections.abc.Set # Handle legacy typing aliases elif hasattr(typing, 'List') and typ is getattr(typing, 'List', None): - return list + return canonicalize(list) elif hasattr(typing, 'Dict') and typ is getattr(typing, 'Dict', None): - return dict + return canonicalize(dict) elif hasattr(typing, 'Set') and typ is getattr(typing, 'Set', None): - return set + return canonicalize(set) elif typ is typing.Callable: return collections.abc.Callable elif typ is typing.Any: return object - elif not isinstance(typ, type) and typing.get_origin(typ) is None: - return canonicalize(nested_type(typ)) else: return typ diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 181f643f..26897a52 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -417,13 +417,13 @@ def variadic_kwargs_func(**kwargs: T) -> T: # Variadic kwargs not supported # Simple generic functions (identity, (int,), {}, int), (identity, (str,), {}, str), - (identity, (list[int],), {}, list[int]), + (identity, (list[int],), {}, collections.abc.Sequence[int]), # Multiple TypeVars (make_pair, (int, str), {}, tuple[int, str]), - (make_pair, (bool, list[float]), {}, tuple[bool, list[float]]), + (make_pair, (bool, list[float]), {}, tuple[bool, collections.abc.Sequence[float]]), # Generic collections (wrap_in_list, (int,), {}, list[int]), - (wrap_in_list, (dict[str, bool],), {}, list[dict[str, bool]]), + (wrap_in_list, (dict[str, bool],), {}, list[collections.abc.Mapping[str, bool]]), (get_first, (list[str],), {}, str), (get_first, (list[tuple[int, float]],), {}, tuple[int, float]), (getitem_mapping, (collections.abc.Mapping[str, int], str), {}, int), @@ -431,11 +431,11 @@ def variadic_kwargs_func(**kwargs: T) -> T: # Variadic kwargs not supported getitem_mapping, (collections.abc.Mapping[bool, list[str]], bool), {}, - list[str], + collections.abc.Sequence[str], ), # Dict operations (dict_values, (dict[str, int],), {}, list[int]), - (dict_values, (dict[bool, list[str]],), {}, list[list[str]]), + (dict_values, (dict[bool, list[str]],), {}, list[collections.abc.Sequence[str]]), # Callable types (process_callable, (collections.abc.Callable[[int], str], int), {}, str), ( @@ -462,7 +462,7 @@ def variadic_kwargs_func(**kwargs: T) -> T: # Variadic kwargs not supported multi_generic, (float, list[float], dict[bool, list[str]]), {}, - tuple[float, bool, list[str]], + tuple[float, bool, collections.abc.Sequence[str]], ), # Same TypeVar used multiple times (same_type_twice, (int, int), {}, int), @@ -569,14 +569,17 @@ def test_infer_return_type_failure( (K, K), (V, V), # Legacy typing aliases are converted to modern forms - (list, list), - (dict, dict), - (set, set), - (list[int], list[int]), - (dict[str, int], dict[str, int]), - (set[bool], set[bool]), - (list[T], list[T]), - (dict[K, V], dict[K, V]), + (list, collections.abc.Sequence), + (dict, collections.abc.Mapping), + (set, collections.abc.Set), + (list[int], collections.abc.Sequence[int]), + (dict[str, int], collections.abc.Mapping[str, int]), + (set[bool], collections.abc.Set[bool]), + (list[T], collections.abc.Sequence[T]), + (dict[K, V], collections.abc.Mapping[K, V]), + # Modern generic aliases pass through with canonicalized args + (list[list[int]], collections.abc.Sequence[collections.abc.Sequence[int]]), # Nested legacy conversion + (dict[str, dict[K, V]], collections.abc.Mapping[str, collections.abc.Mapping[K, V]]), # typing.Callable becomes collections.abc.Callable (typing.Callable, collections.abc.Callable), (typing.Callable[[int], str], collections.abc.Callable[[int], str]), @@ -589,15 +592,9 @@ def test_infer_return_type_failure( (typing.Annotated[int, "metadata"], int), (typing.Annotated[str, "doc string"], str), (typing.Annotated[list[T], "annotation"], list[T]), - (typing.Annotated[dict[K, V], "complex", "multi"], dict[K, V]), + (typing.Annotated[dict[K, V], "complex", "multi"], collections.abc.Mapping[K, V]), # Nested Annotated unwrapping (typing.Annotated[typing.Annotated[int, "inner"], "outer"], int), - # Modern generic aliases pass through with canonicalized args - (list[int], list[int]), - (dict[str, bool], dict[str, bool]), - (set[float], set[float]), - (list[list[int]], list[list[int]]), # Nested legacy conversion - (dict[str, dict[K, V]], dict[str, dict[K, V]]), # Union types are canonicalized with | operator (typing.Union[int, str], int | str), (typing.Union[T, int], T | int), @@ -623,35 +620,6 @@ def test_infer_return_type_failure( (list[dict[set[T], list[V]]], list[dict[set[T], list[V]]]), # Callable with nested canonicalization (typing.Callable[[list[T]], dict[K, V]], collections.abc.Callable[[list[T]], dict[K, V]]), - # Values that will use _nested_type (basic values) - (42, int), - ("hello", str), - (3.14, float), - (True, bool), - (None, type(None)), - # Collections that will use _nested_type - ([1, 2, 3], collections.abc.Sequence[int]), - (["a", "b"], collections.abc.Sequence[str]), - ({"key": "value"}, collections.abc.Mapping[str, str]), - ({1: "one", 2: "two"}, collections.abc.Mapping[int, str]), - ({1, 2, 3}, collections.abc.Set[int]), - ({"x", "y"}, collections.abc.Set[str]), - # Tuples get special handling - ((1, "hello", 3.14), tuple[int, str, float]), - ((1,), tuple[int]), - ((), tuple), - # Empty collections - ([], list), - ({}, dict), - (set(), set), - # Nested collections via _nested_type - ([{1: "one"}, {2: "two"}], collections.abc.Sequence[collections.abc.Mapping[int, str]]), - ({"key": [1, 2, 3]}, collections.abc.Mapping[str, collections.abc.Sequence[int]]), - ([(1, "a"), (2, "b")], collections.abc.Sequence[tuple[int, str]]), - # Functions/callables - (lambda x: x, type(lambda x: x)), - # bytes handling - (b"hello", bytes), ], ids=str, ) @@ -808,29 +776,6 @@ class ConcreteInterpretation(dict): assert result == dict # Empty dict should return dict, not Interpretation -def test_canonicalize_with_nested_type_integration(): - """Test that canonicalize properly uses _nested_type for values""" - # Test various values go through _nested_type in canonicalize - assert canonicalize([1, 2, 3]) == collections.abc.Sequence[int] - assert canonicalize({"a": 1}) == collections.abc.Mapping[str, int] - assert canonicalize({1, 2}) == collections.abc.Set[int] - assert canonicalize((1, "x")) == tuple[int, str] - - # Nested case - nested = [{"a": (1, 2)}, {"b": (3, 4)}] - expected = collections.abc.Sequence[collections.abc.Mapping[str, tuple[int, int]]] - assert canonicalize(nested) == expected - - -def test_canonicalize_preserves_modern_syntax(): - """Test that already-modern syntax is preserved""" - # These should pass through unchanged - assert canonicalize(list[int]) == list[int] - assert canonicalize(dict[str, int]) == dict[str, int] - assert canonicalize(int | str) == int | str - assert canonicalize(list[T] | None) == list[T] | type(None) - - def test_canonicalize_union_ordering(): """Test union type canonicalization handles ordering consistently""" # The | operator may reorder types, but the result should be equivalent From e79b5d3277dea60b4a7f243602b885a335490273 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 10:01:48 -0400 Subject: [PATCH 33/76] update canonicalize test --- tests/test_internals_unification.py | 27 +++++++++++---------------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 26897a52..6aba0b0d 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -560,10 +560,10 @@ def test_infer_return_type_failure( (str, str), (float, float), (bool, bool), - (list, list), - (dict, dict), - (set, set), (tuple, tuple), + (list, collections.abc.Sequence), + (dict, collections.abc.Mapping), + (set, collections.abc.Set), # TypeVars are preserved (T, T), (K, K), @@ -591,35 +591,30 @@ def test_infer_return_type_failure( # Annotated types are unwrapped (typing.Annotated[int, "metadata"], int), (typing.Annotated[str, "doc string"], str), - (typing.Annotated[list[T], "annotation"], list[T]), + (typing.Annotated[list[T], "annotation"], collections.abc.Sequence[T]), (typing.Annotated[dict[K, V], "complex", "multi"], collections.abc.Mapping[K, V]), # Nested Annotated unwrapping (typing.Annotated[typing.Annotated[int, "inner"], "outer"], int), # Union types are canonicalized with | operator (typing.Union[int, str], int | str), (typing.Union[T, int], T | int), - (typing.Union[list[T], dict[K, V]], list[T] | dict[K, V]), + (typing.Union[list[T], dict[K, V]], collections.abc.Sequence[T] | collections.abc.Mapping[K, V]), (typing.Union[int, str, bool], int | str | bool), # Nested unions - (typing.Union[list[int], None], list[int] | type(None)), - (typing.Union[dict[K, V], set[T]], dict[K, V] | set[T]), + (typing.Union[list[int], None], collections.abc.Sequence[int] | type(None)), + (typing.Union[dict[K, V], set[T]], collections.abc.Mapping[K, V] | collections.abc.Set[T]), # Complex nested canonicalization - (list[dict[str, T] | None], list[dict[str, T] | type(None)]), - (dict[K, list[V | None]], dict[K, list[V | type(None)]]), - # Already canonical forms pass through - (list[T], list[T]), - (dict[K, V], dict[K, V]), - (int | str, int | str), - (list[int] | None, list[int] | type(None)), + (list[dict[str, T] | None], collections.abc.Sequence[collections.abc.Mapping[str, T] | type(None)]), + (dict[K, list[V | None]], collections.abc.Mapping[K, collections.abc.Sequence[V | type(None)]]), # None type handling (type(None), type(None)), (typing.Union[int, None], int | type(None)), # Ellipsis type (type(...), type(...)), # Generic aliases with multiple levels of nesting - (list[dict[set[T], list[V]]], list[dict[set[T], list[V]]]), + (list[dict[set[T], list[V]]], collections.abc.Sequence[collections.abc.Mapping[collections.abc.Set[T], collections.abc.Sequence[V]]]), # Callable with nested canonicalization - (typing.Callable[[list[T]], dict[K, V]], collections.abc.Callable[[list[T]], dict[K, V]]), + (typing.Callable[[list[T]], dict[K, V]], collections.abc.Callable[[collections.abc.Sequence[T]], collections.abc.Mapping[K, V]]), ], ids=str, ) From ef342472152e9538769abfa61e25f73e6b0094b2 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 10:13:21 -0400 Subject: [PATCH 34/76] remove dead path --- effectful/internals/unification.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 2eb97062..d4f21b42 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -519,8 +519,6 @@ def nested_type(value) -> type: if isinstance(value, Term): raise TypeError(f"Terms should not appear in nested_type, but got {value}") - elif value is None: - return type(None) elif not isinstance(value, type) and typing.get_origin(value) is None: return type(value) else: From 748bcac1a4bf1e88935aa40fe75dc162ef808eaf Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 10:37:33 -0400 Subject: [PATCH 35/76] paramspec failure --- effectful/internals/unification.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index d4f21b42..a66c1a6b 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -127,9 +127,7 @@ def infer_return_type( if result_type is None: result_type = type(None) result_type = substitute(result_type, subs) - if freetypevars(result_type) and not issubclass( - typing.get_origin(result_type), collections.abc.Callable - ): + if freetypevars(result_type) and typing.get_origin(result_type) is not collections.abc.Callable: raise TypeError( "Return type cannot have free type variables after substitution" ) @@ -230,6 +228,9 @@ def unify( if typ in subs: subs = unify(subs[typ], subtyp, subs) return {**subs, **{typ: subtyp}} + elif isinstance(typ, typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs) or \ + isinstance(subtyp, typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs): + raise TypeError("ParamSpec handling is not implemented") elif typing.get_origin(typ) in {typing.Union, types.UnionType} or \ typing.get_origin(subtyp) in {typing.Union, types.UnionType}: # TODO handle UnionType properly From 08d149f59457108f8e82d925aaf37a3f1132802f Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 11:01:20 -0400 Subject: [PATCH 36/76] literal and optional --- effectful/internals/unification.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index a66c1a6b..c399c2f5 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -235,7 +235,8 @@ def unify( typing.get_origin(subtyp) in {typing.Union, types.UnionType}: # TODO handle UnionType properly return unify(typing.get_args(typ), typing.get_args(subtyp), subs) - elif typing.get_args(typ) and typing.get_args(subtyp): + elif isinstance(typ, typing._GenericAlias | types.GenericAlias) and \ + isinstance(subtyp, typing._GenericAlias | types.GenericAlias): subs = unify(typing.get_origin(typ), typing.get_origin(subtyp), subs) return unify(typing.get_args(typ), typing.get_args(subtyp), subs) elif isinstance(typ, list | tuple) and isinstance(subtyp, list | tuple) and len(typ) == len(subtyp): @@ -364,6 +365,13 @@ def canonicalize( else: # Handle other Callable formats return origin[tuple(canonicalize(a) for a in args)] + elif typing.get_origin(typ) is typing.Literal: + t = type(typing.get_args(typ)[0]) + for arg in typing.get_args(typ)[1:]: + t = t | type(arg) + return canonicalize(t) + elif typing.get_origin(typ) is typing.Optional: + return canonicalize(None | typing.get_args(typ)[0]) elif isinstance(typ, typing._GenericAlias | types.GenericAlias) and typing.get_origin(typ) is not typ: # type: ignore # Handle generic types origin = typing.get_origin(typ) From 0efe01da0d3a1c65003946d6023f7ee7d9d462ba Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 11:04:03 -0400 Subject: [PATCH 37/76] paramspec canonicalize --- effectful/internals/unification.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index c399c2f5..f9f25f03 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -353,6 +353,8 @@ def canonicalize( return t elif isinstance(typ, typing.TypeVar): return typ + elif isinstance(typ, typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs): + return typ elif typing.get_origin(typ) is collections.abc.Callable: origin, args = typing.get_origin(typ), typing.get_args(typ) if not args: From b52c7dfa8626be052e035cbafe44096de380bf71 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 11:11:13 -0400 Subject: [PATCH 38/76] simplify alias --- effectful/internals/unification.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index f9f25f03..bbc24de9 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -388,11 +388,13 @@ def canonicalize( elif typ is set: return collections.abc.Set # Handle legacy typing aliases - elif hasattr(typing, 'List') and typ is getattr(typing, 'List', None): + elif typ is typing.Tuple: + return canonicalize(tuple) + elif typ is typing.List: return canonicalize(list) - elif hasattr(typing, 'Dict') and typ is getattr(typing, 'Dict', None): + elif typ is typing.Dict: return canonicalize(dict) - elif hasattr(typing, 'Set') and typ is getattr(typing, 'Set', None): + elif typ is typing.Set: return canonicalize(set) elif typ is typing.Callable: return collections.abc.Callable From c5f0f1b0a5ac1c08089037f95e0b932d1740198e Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 11:25:22 -0400 Subject: [PATCH 39/76] single unify call --- effectful/internals/unification.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index bbc24de9..62acaadd 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -97,28 +97,24 @@ def infer_return_type( raise TypeError("unbound type variables in return type") # Build substitution map - subs: collections.abc.Mapping[typing.TypeVar, type] = {} + arg_annos = [] + arg_types = [] for name, param in sig.parameters.items(): if param.kind is inspect.Parameter.VAR_POSITIONAL: for arg in bound_sig.arguments[name]: - subs = unify( - canonicalize(param.annotation), - canonicalize(nested_type(arg)), - subs, - ) + arg_annos += [param.annotation] + arg_types += [arg] elif param.kind is inspect.Parameter.VAR_KEYWORD: for arg in bound_sig.arguments[name].values(): - subs = unify( - canonicalize(param.annotation), - canonicalize(nested_type(arg)), - subs, - ) + arg_annos += [param.annotation] + arg_types += [arg] else: - subs = unify( - canonicalize(param.annotation), - canonicalize(nested_type(bound_sig.arguments[name])), - subs, - ) + arg_annos += [param.annotation] + arg_types += [bound_sig.arguments[name]] + + arg_annos = [canonicalize(a) for a in arg_annos] + arg_types = [canonicalize(nested_type(a)) for a in arg_types] + subs = unify(arg_annos, arg_types, {}) # Apply substitutions to return type result_type = sig.return_annotation From bd006b8e0ae1e053c347da996aa6347aaa69e065 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 11:42:56 -0400 Subject: [PATCH 40/76] fast paths --- effectful/internals/unification.py | 30 ++++++++++++++++------------- tests/test_internals_unification.py | 5 ----- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 62acaadd..f712c74b 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -82,10 +82,19 @@ def infer_return_type( bound_sig.apply_defaults() sig: inspect.Signature = bound_sig.signature - # validate that the function has a signature with well-formed type annotations - if sig.return_annotation is inspect.Signature.empty: - raise TypeError("Function must have a return type annotation") + return_anno = sig.return_annotation + if typing.get_origin(return_anno) is typing.Annotated: + return_anno = typing.get_args(return_anno)[0] + + # fast path for simple cases + if return_anno is inspect.Signature.empty: + return object + elif return_anno is None: + return type(None) + elif not freetypevars(return_anno): + return return_anno + # validate that the function has a signature with well-formed type annotations result_fvs: set[typing.TypeVar] = freetypevars(sig.return_annotation) pattern_fvs: set[typing.TypeVar] = set().union(*(freetypevars(p.annotation) for p in sig.parameters.values())) concrete_fvs: set[typing.TypeVar] = set().union(*(freetypevars(arg) for arg in bound_sig.arguments.values())) @@ -114,20 +123,15 @@ def infer_return_type( arg_annos = [canonicalize(a) for a in arg_annos] arg_types = [canonicalize(nested_type(a)) for a in arg_types] - subs = unify(arg_annos, arg_types, {}) + subs = unify(arg_annos, arg_types) # Apply substitutions to return type - result_type = sig.return_annotation - if typing.get_origin(result_type) is typing.Annotated: - result_type = typing.get_args(result_type)[0] - if result_type is None: - result_type = type(None) - result_type = substitute(result_type, subs) - if freetypevars(result_type) and typing.get_origin(result_type) is not collections.abc.Callable: + return_anno = substitute(return_anno, subs) + if freetypevars(return_anno) and typing.get_origin(return_anno) is not collections.abc.Callable: raise TypeError( "Return type cannot have free type variables after substitution" ) - return result_type + return return_anno def unify( @@ -141,7 +145,7 @@ def unify( | types.UnionType | types.GenericAlias | collections.abc.Sequence, - subs: collections.abc.Mapping[typing.TypeVar, type], + subs: collections.abc.Mapping[typing.TypeVar, type] = {}, ) -> collections.abc.Mapping[typing.TypeVar, type]: """ Unify a pattern type with a concrete type, returning a substitution map. diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 6aba0b0d..8cea3e81 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -526,11 +526,6 @@ def no_param_annotation(x) -> T: # No parameter annotation {}, ), # Missing annotations - ( - no_return_annotation, - (int,), - {}, - ), ( no_param_annotation, (int,), From 84e87347084cef49f41a116b115a0eeebb7c927b Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 12:02:27 -0400 Subject: [PATCH 41/76] union handling --- effectful/internals/unification.py | 20 ++++++++++++++++---- tests/test_internals_unification.py | 18 ------------------ 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index f712c74b..b6beffc2 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -231,10 +231,22 @@ def unify( elif isinstance(typ, typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs) or \ isinstance(subtyp, typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs): raise TypeError("ParamSpec handling is not implemented") - elif typing.get_origin(typ) in {typing.Union, types.UnionType} or \ - typing.get_origin(subtyp) in {typing.Union, types.UnionType}: - # TODO handle UnionType properly - return unify(typing.get_args(typ), typing.get_args(subtyp), subs) + elif typing.get_origin(typ) in {typing.Union, types.UnionType}: + any_succeeded = False + for arg in typing.get_args(typ): + try: + subs = unify(arg, subtyp, subs) + any_succeeded = True + except TypeError: + continue + if any_succeeded: + return subs + else: + raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}") + elif typing.get_origin(subtyp) in {typing.Union, types.UnionType}: + for arg in typing.get_args(subtyp): + subs = unify(typ, arg, subs) + return subs elif isinstance(typ, typing._GenericAlias | types.GenericAlias) and \ isinstance(subtyp, typing._GenericAlias | types.GenericAlias): subs = unify(typing.get_origin(typ), typing.get_origin(subtyp), subs) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 8cea3e81..84b6e24d 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -279,21 +279,6 @@ def test_substitute( {}, {T: int, U: str, V: bool}, ), - # Union types - basic element-wise unification (current implementation) - # Note: Current unify treats union args as sequences, not true union logic - ( - T | V, - int | str, - {}, - {T: int, V: str}, - ), # Element-wise unification of TypeVars - (T | V, int | str, {}, {T: int, V: str}), # typing.Union syntax - # Simple union compatibility - TypeVar gets unified with itself - (T | int, T | int, {}, {T: T}), # Identical unions - T unifies with T - (T | int, T | int, {}, {T: T}), # Identical typing.Union - # Sequence unification (tuples as sequences) - ((T, V), (int, str), {}, {T: int, V: str}), - ([T, V], [int, str], {}, {T: int, V: str}), # Complex combinations ( dict[K, collections.abc.Callable[[T], V]], @@ -335,9 +320,6 @@ def test_unify_success( # Sequence length mismatch ((T, V), (int,)), ([T, V], [int, str, bool]), - # Union failure cases - (T | int, V | str), # typing.Union mismatch - (T | V, int | str | bool), # Different union sizes ], ids=str, ) From a6426ae734b21338d746057a49a13af54f77e07c Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 12:25:04 -0400 Subject: [PATCH 42/76] freshen --- effectful/internals/unification.py | 44 +++++++++++++++++++----------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index b6beffc2..e988313d 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -1,6 +1,7 @@ import collections.abc import functools import inspect +import random import types import typing @@ -94,17 +95,6 @@ def infer_return_type( elif not freetypevars(return_anno): return return_anno - # validate that the function has a signature with well-formed type annotations - result_fvs: set[typing.TypeVar] = freetypevars(sig.return_annotation) - pattern_fvs: set[typing.TypeVar] = set().union(*(freetypevars(p.annotation) for p in sig.parameters.values())) - concrete_fvs: set[typing.TypeVar] = set().union(*(freetypevars(arg) for arg in bound_sig.arguments.values())) - if (result_fvs | pattern_fvs) & concrete_fvs: - raise TypeError( - "Cannot unify free type variables in pattern and concrete types" - ) - if not result_fvs <= pattern_fvs: - raise TypeError("unbound type variables in return type") - # Build substitution map arg_annos = [] arg_types = [] @@ -122,18 +112,40 @@ def infer_return_type( arg_types += [bound_sig.arguments[name]] arg_annos = [canonicalize(a) for a in arg_annos] - arg_types = [canonicalize(nested_type(a)) for a in arg_types] + arg_types = [freshen(canonicalize(nested_type(a))) for a in arg_types] subs = unify(arg_annos, arg_types) # Apply substitutions to return type return_anno = substitute(return_anno, subs) - if freetypevars(return_anno) and typing.get_origin(return_anno) is not collections.abc.Callable: - raise TypeError( - "Return type cannot have free type variables after substitution" - ) + if isinstance(return_anno, typing.TypeVar): + raise TypeError(f"Unbound type variable {return_anno} in return type") return return_anno +def freshen(tp): + """ + Return a freshened version of the given type expression. + + This function replaces all TypeVars in the type expression with new TypeVars + that have unique names, ensuring that the resulting type has no free TypeVars. + It is useful for creating fresh type variables in generic programming contexts. + + Args: + tp: The type expression to freshen. Can be a plain type, TypeVar, + generic alias, or union type. + + Returns: + A new type expression with all TypeVars replaced by fresh TypeVars. + + Examples: + >>> import typing + >>> T = typing.TypeVar('T') + >>> freshen(T) + ~T_12345678 # Example output with a random suffix + """ + return substitute(tp, {fv: typing.TypeVar(name=f"{fv.__name__}_{random.randint(0, 1 << 32)}") for fv in freetypevars(tp)}) + + def unify( typ: type | typing.TypeVar From 3a8f726e8322b4054df69386729798c07c7eb265 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 12:38:17 -0400 Subject: [PATCH 43/76] type of freshen --- effectful/internals/unification.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index e988313d..19fce04b 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -122,7 +122,7 @@ def infer_return_type( return return_anno -def freshen(tp): +def freshen(tp: type | typing.TypeVar | types.GenericAlias | types.UnionType): """ Return a freshened version of the given type expression. @@ -143,7 +143,16 @@ def freshen(tp): >>> freshen(T) ~T_12345678 # Example output with a random suffix """ - return substitute(tp, {fv: typing.TypeVar(name=f"{fv.__name__}_{random.randint(0, 1 << 32)}") for fv in freetypevars(tp)}) + return substitute(tp, { + fv: typing.TypeVar( + name=f"{fv.__name__}_{random.randint(0, 1 << 32)}", + bound=fv.__bound__, + default=fv.__default__, + covariant=fv.__covariant__, + contravariant=fv.__contravariant__, + ) + for fv in freetypevars(tp) + }) def unify( @@ -240,6 +249,10 @@ def unify( if typ in subs: subs = unify(subs[typ], subtyp, subs) return {**subs, **{typ: subtyp}} + elif isinstance(subtyp, typing.TypeVar): + if subtyp in subs: + subs = unify(typ, subs[subtyp], subs) + return {**subs, **{subtyp: typ}} elif isinstance(typ, typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs) or \ isinstance(subtyp, typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs): raise TypeError("ParamSpec handling is not implemented") @@ -703,7 +716,7 @@ def freetypevars( def substitute( typ: type | types.GenericAlias | types.UnionType, - subs: collections.abc.Mapping[typing.TypeVar, type], + subs: collections.abc.Mapping[typing.TypeVar, type | typing.TypeVar], ) -> type | types.GenericAlias | types.UnionType: """ Substitute type variables in a type expression with concrete types. From 4639ce073ddc25991d18dab3fd5cb26ffac8ce72 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 12:38:53 -0400 Subject: [PATCH 44/76] reorder --- effectful/internals/unification.py | 66 +++++++++++++++--------------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 19fce04b..b5cdc8fa 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -122,39 +122,6 @@ def infer_return_type( return return_anno -def freshen(tp: type | typing.TypeVar | types.GenericAlias | types.UnionType): - """ - Return a freshened version of the given type expression. - - This function replaces all TypeVars in the type expression with new TypeVars - that have unique names, ensuring that the resulting type has no free TypeVars. - It is useful for creating fresh type variables in generic programming contexts. - - Args: - tp: The type expression to freshen. Can be a plain type, TypeVar, - generic alias, or union type. - - Returns: - A new type expression with all TypeVars replaced by fresh TypeVars. - - Examples: - >>> import typing - >>> T = typing.TypeVar('T') - >>> freshen(T) - ~T_12345678 # Example output with a random suffix - """ - return substitute(tp, { - fv: typing.TypeVar( - name=f"{fv.__name__}_{random.randint(0, 1 << 32)}", - bound=fv.__bound__, - default=fv.__default__, - covariant=fv.__covariant__, - contravariant=fv.__contravariant__, - ) - for fv in freetypevars(tp) - }) - - def unify( typ: type | typing.TypeVar @@ -286,6 +253,39 @@ def unify( raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}") +def freshen(tp: type | typing.TypeVar | types.GenericAlias | types.UnionType): + """ + Return a freshened version of the given type expression. + + This function replaces all TypeVars in the type expression with new TypeVars + that have unique names, ensuring that the resulting type has no free TypeVars. + It is useful for creating fresh type variables in generic programming contexts. + + Args: + tp: The type expression to freshen. Can be a plain type, TypeVar, + generic alias, or union type. + + Returns: + A new type expression with all TypeVars replaced by fresh TypeVars. + + Examples: + >>> import typing + >>> T = typing.TypeVar('T') + >>> freshen(T) + ~T_12345678 # Example output with a random suffix + """ + return substitute(tp, { + fv: typing.TypeVar( + name=f"{fv.__name__}_{random.randint(0, 1 << 32)}", + bound=fv.__bound__, + default=fv.__default__, + covariant=fv.__covariant__, + contravariant=fv.__contravariant__, + ) + for fv in freetypevars(tp) + }) + + def canonicalize( typ: type | typing.TypeVar | types.GenericAlias | types.UnionType, ) -> type: From ab7b1ffc558d158a84896044dbb1079557d5e776 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 14:22:00 -0400 Subject: [PATCH 45/76] fixpoint in substitute --- effectful/internals/unification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index b5cdc8fa..2fba4d22 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -763,7 +763,7 @@ def substitute( """ if isinstance(typ, typing.TypeVar): - return subs.get(typ, typ) + return substitute(subs[typ], subs) if typ in subs else typ elif isinstance(typ, list | tuple): # Handle plain lists/sequences (e.g., in Callable's parameter list) return type(typ)(substitute(item, subs) for item in typ) From de9e7b68a08d792dd2938eb1412bac391db65e77 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 14:26:58 -0400 Subject: [PATCH 46/76] tweak types --- effectful/internals/unification.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 2fba4d22..0e4cc2b6 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -117,19 +117,19 @@ def infer_return_type( # Apply substitutions to return type return_anno = substitute(return_anno, subs) - if isinstance(return_anno, typing.TypeVar): - raise TypeError(f"Unbound type variable {return_anno} in return type") return return_anno def unify( typ: type | typing.TypeVar + | typing.ParamSpec | types.GenericAlias | types.UnionType | collections.abc.Sequence, subtyp: type | typing.TypeVar + | typing.ParamSpec | types.UnionType | types.GenericAlias | collections.abc.Sequence, @@ -649,8 +649,8 @@ def _(value: range) -> type: def freetypevars( - typ: type | typing.TypeVar | types.GenericAlias | types.UnionType | types.NoneType, -) -> set[typing.TypeVar]: + typ: type | typing.TypeVar | types.GenericAlias | types.UnionType | types.NoneType | typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs, +) -> set[typing.TypeVar | typing.ParamSpec]: """ Return a set of free type variables in the given type expression. @@ -700,8 +700,10 @@ def freetypevars( >>> freetypevars(dict[str, T]) {~T} """ - if isinstance(typ, typing.TypeVar): + if isinstance(typ, typing.TypeVar | typing.ParamSpec): return {typ} + elif isinstance(typ, typing.ParamSpecArgs | typing.ParamSpecKwargs): + return freetypevars(typing.get_origin(typ)) elif typing.get_origin(typ) is typing.Annotated: return freetypevars(typing.get_args(typ)[0]) elif isinstance(typ, list | tuple): From 9dc9acdc1c73ff1d77bb554c4cebdae3c0397b55 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 14:43:29 -0400 Subject: [PATCH 47/76] empty params --- effectful/internals/unification.py | 4 ++++ tests/test_internals_unification.py | 6 ------ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 0e4cc2b6..08012321 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -94,6 +94,8 @@ def infer_return_type( return type(None) elif not freetypevars(return_anno): return return_anno + elif any(p.annotation is inspect.Signature.empty for p in sig.parameters.values()): + raise TypeError("Function has parameters without type annotations, cannot infer return type") # Build substitution map arg_annos = [] @@ -424,6 +426,8 @@ def canonicalize( return collections.abc.Mapping elif typ is set: return collections.abc.Set + elif issubclass(typ, range): + return collections.abc.Sequence[int] # Handle legacy typing aliases elif typ is typing.Tuple: return canonicalize(tuple) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 84b6e24d..c7b7fb27 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -501,12 +501,6 @@ def no_param_annotation(x) -> T: # No parameter annotation @pytest.mark.parametrize( "func,args,kwargs", [ - # Unbound type variable in return - ( - unbound_typevar_func, - (int,), - {}, - ), # Missing annotations ( no_param_annotation, From 7fd9927e2b73d8f027a98251112fef720e6843ce Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 14:50:08 -0400 Subject: [PATCH 48/76] tweak substitute type --- effectful/internals/unification.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 08012321..98208c4a 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -135,8 +135,8 @@ def unify( | types.UnionType | types.GenericAlias | collections.abc.Sequence, - subs: collections.abc.Mapping[typing.TypeVar, type] = {}, -) -> collections.abc.Mapping[typing.TypeVar, type]: + subs: collections.abc.Mapping[typing.TypeVar, type | typing.TypeVar | typing.ParamSpec] = {}, +) -> collections.abc.Mapping[typing.TypeVar, type | typing.TypeVar | typing.ParamSpec]: """ Unify a pattern type with a concrete type, returning a substitution map. @@ -722,7 +722,7 @@ def freetypevars( def substitute( typ: type | types.GenericAlias | types.UnionType, - subs: collections.abc.Mapping[typing.TypeVar, type | typing.TypeVar], + subs: collections.abc.Mapping[typing.TypeVar, type | typing.ParamSpec | typing.TypeVar], ) -> type | types.GenericAlias | types.UnionType: """ Substitute type variables in a type expression with concrete types. From 7330f23647c61bbc8ef250bd77fcdceb187300a3 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 14:57:55 -0400 Subject: [PATCH 49/76] doctest --- effectful/internals/unification.py | 53 +++++++++--------------------- 1 file changed, 16 insertions(+), 37 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 98208c4a..3c47141a 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -70,15 +70,6 @@ def infer_return_type( >>> bound = sig.bind() >>> infer_return_type(bound) - - >>> # Error: unbound type variable in return - >>> def bad_func(x: T) -> tuple[T, K]: ... # K not in parameters - >>> sig = inspect.signature(bad_func) - >>> bound = sig.bind(int) - >>> infer_return_type(bound) # doctest: +ELLIPSIS - Traceback (most recent call last): - ... - TypeError: unbound type variables in return type """ bound_sig.apply_defaults() sig: inspect.Signature = bound_sig.signature @@ -273,8 +264,10 @@ def freshen(tp: type | typing.TypeVar | types.GenericAlias | types.UnionType): Examples: >>> import typing >>> T = typing.TypeVar('T') - >>> freshen(T) - ~T_12345678 # Example output with a random suffix + >>> isinstance(freshen(T), typing.TypeVar) + True + >>> freshen(T) == T + False """ return substitute(tp, { fv: typing.TypeVar( @@ -321,40 +314,38 @@ def canonicalize( >>> K = typing.TypeVar('K') >>> V = typing.TypeVar('V') - # Legacy typing aliases are converted to modern forms - >>> canonicalize(typing.List[int]) - list[int] - >>> canonicalize(typing.Dict[str, int]) - dict[str, int] - >>> canonicalize(typing.Set[bool]) - set[bool] - >>> canonicalize(typing.Callable[[int], str]) - collections.abc.Callable[[int], str] + # Plain types pass through unchanged + >>> canonicalize(int) + + >>> canonicalize(str) + # TypeVars are preserved unchanged >>> canonicalize(T) ~T + + # Containers are normalized to abstract collections >>> canonicalize(list[T]) - list[~T] + collections.abc.Sequence[~T] # Annotated types are unwrapped >>> canonicalize(typing.Annotated[int, "metadata"]) >>> canonicalize(typing.Annotated[list[str], "doc string"]) - list[str] + collections.abc.Sequence[str] # Nested generic types are recursively canonicalized >>> canonicalize(typing.List[typing.Dict[K, V]]) - list[dict[~K, ~V]] + collections.abc.Sequence[collections.abc.Mapping[~K, ~V]] >>> canonicalize(typing.Dict[str, typing.List[T]]) - dict[str, list[~T]] + collections.abc.Mapping[str, collections.abc.Sequence[~T]] # Union types are canonicalized with | operator >>> result = canonicalize(typing.Union[int, str]) >>> result == int | str True >>> result = canonicalize(typing.Union[list[T], dict[K, V]]) - >>> result == list[T] | dict[K, V] + >>> result == collections.abc.Sequence[T] | collections.abc.Mapping[K, V] True # typing.Any becomes object @@ -364,18 +355,6 @@ def canonicalize( # inspect.Parameter.empty becomes object (via Any) >>> canonicalize(inspect.Parameter.empty) - - # Plain types pass through unchanged - >>> canonicalize(int) - - >>> canonicalize(str) - - >>> canonicalize(list) - - - # Complex nested canonicalization - >>> canonicalize(typing.List[typing.Union[typing.Dict[str, T], None]]) - list[dict[str, ~T] | None] """ if typing.get_origin(typ) is typing.Annotated: return canonicalize(typing.get_args(typ)[0]) From acf9f11f955649028d83ce2e9be1d12c94f0c5ff Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 15:07:14 -0400 Subject: [PATCH 50/76] separate nested_type --- effectful/internals/unification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 3c47141a..178d03ce 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -105,7 +105,7 @@ def infer_return_type( arg_types += [bound_sig.arguments[name]] arg_annos = [canonicalize(a) for a in arg_annos] - arg_types = [freshen(canonicalize(nested_type(a))) for a in arg_types] + arg_types = [freshen(canonicalize(a)) for a in arg_types] subs = unify(arg_annos, arg_types) # Apply substitutions to return type From 9ec83f82533cbe44d02ff104052078f7bc0e26b9 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 8 Jul 2025 15:30:01 -0400 Subject: [PATCH 51/76] handle defaults --- effectful/internals/unification.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 178d03ce..3648f715 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -71,7 +71,6 @@ def infer_return_type( >>> infer_return_type(bound) """ - bound_sig.apply_defaults() sig: inspect.Signature = bound_sig.signature return_anno = sig.return_annotation @@ -92,7 +91,9 @@ def infer_return_type( arg_annos = [] arg_types = [] for name, param in sig.parameters.items(): - if param.kind is inspect.Parameter.VAR_POSITIONAL: + if name not in bound_sig.arguments: + continue + elif param.kind is inspect.Parameter.VAR_POSITIONAL: for arg in bound_sig.arguments[name]: arg_annos += [param.annotation] arg_types += [arg] @@ -405,7 +406,7 @@ def canonicalize( return collections.abc.Mapping elif typ is set: return collections.abc.Set - elif issubclass(typ, range): + elif isinstance(typ, type) and issubclass(typ, range): return collections.abc.Sequence[int] # Handle legacy typing aliases elif typ is typing.Tuple: From d7fc53bd3f9e83ca7b56db267e7613a1f58e2a9f Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 9 Jul 2025 20:58:40 -0400 Subject: [PATCH 52/76] update --- effectful/internals/unification.py | 509 ++++++++++++---------------- tests/test_internals_unification.py | 147 ++------ 2 files changed, 240 insertions(+), 416 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 3648f715..85f7dda4 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -6,129 +6,14 @@ import typing -def infer_return_type( - bound_sig: inspect.BoundArguments, -) -> type | types.GenericAlias | types.UnionType: - """ - Infer the return type of a function based on its signature and argument types. - - This function takes a BoundArguments object (created by binding concrete argument - types to a function signature) and infers what the return type should be by: - 1. Finding all TypeVars in the function's parameter and return annotations - 2. Unifying the parameter type annotations with the concrete argument types - 3. Applying the resulting TypeVar substitutions to the return type annotation - - The function ensures that all type variables in the return type can be inferred - from the parameter types (no unbound type variables in the return). - - Args: - bound_sig: A BoundArguments object obtained by calling - inspect.signature(func).bind(*arg_types, **kwarg_types) - where arg_types and kwarg_types are concrete types - - Returns: - The inferred return type with all TypeVars substituted with concrete types +Substitutions = collections.abc.Mapping[ + typing.TypeVar | typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs, + type | typing.TypeVar | typing.ParamSpec | collections.abc.Sequence | collections.abc.Mapping[str, typing.Any] +] - Raises: - TypeError: If the function lacks required type annotations, has unbound - type variables in the return type, if unification fails, - if the function uses variadic parameters (*args, **kwargs), - or if parameters have free type variables. - - Examples: - >>> import inspect - >>> import typing - >>> from effectful.internals.unification import infer_return_type - >>> T = typing.TypeVar('T') - >>> K = typing.TypeVar('K') - >>> V = typing.TypeVar('V') - >>> # Simple generic function - >>> def identity(x: T) -> T: ... - >>> sig = inspect.signature(identity) - >>> bound = sig.bind(int) - >>> infer_return_type(bound) - - - >>> # Function with multiple TypeVars - >>> def make_dict(key: K, value: V) -> dict[K, V]: ... - >>> sig = inspect.signature(make_dict) - >>> bound = sig.bind(str, int) - >>> infer_return_type(bound) - dict[str, int] - - >>> # Function with nested generics - >>> def wrap_in_list(x: T) -> list[T]: ... - >>> sig = inspect.signature(wrap_in_list) - >>> bound = sig.bind(bool) - >>> infer_return_type(bound) - list[bool] - - >>> # Function with no TypeVars - >>> def get_int() -> int: ... - >>> sig = inspect.signature(get_int) - >>> bound = sig.bind() - >>> infer_return_type(bound) - - """ - sig: inspect.Signature = bound_sig.signature - - return_anno = sig.return_annotation - if typing.get_origin(return_anno) is typing.Annotated: - return_anno = typing.get_args(return_anno)[0] - - # fast path for simple cases - if return_anno is inspect.Signature.empty: - return object - elif return_anno is None: - return type(None) - elif not freetypevars(return_anno): - return return_anno - elif any(p.annotation is inspect.Signature.empty for p in sig.parameters.values()): - raise TypeError("Function has parameters without type annotations, cannot infer return type") - - # Build substitution map - arg_annos = [] - arg_types = [] - for name, param in sig.parameters.items(): - if name not in bound_sig.arguments: - continue - elif param.kind is inspect.Parameter.VAR_POSITIONAL: - for arg in bound_sig.arguments[name]: - arg_annos += [param.annotation] - arg_types += [arg] - elif param.kind is inspect.Parameter.VAR_KEYWORD: - for arg in bound_sig.arguments[name].values(): - arg_annos += [param.annotation] - arg_types += [arg] - else: - arg_annos += [param.annotation] - arg_types += [bound_sig.arguments[name]] - - arg_annos = [canonicalize(a) for a in arg_annos] - arg_types = [freshen(canonicalize(a)) for a in arg_types] - subs = unify(arg_annos, arg_types) - - # Apply substitutions to return type - return_anno = substitute(return_anno, subs) - return return_anno - - -def unify( - typ: type - | typing.TypeVar - | typing.ParamSpec - | types.GenericAlias - | types.UnionType - | collections.abc.Sequence, - subtyp: type - | typing.TypeVar - | typing.ParamSpec - | types.UnionType - | types.GenericAlias - | collections.abc.Sequence, - subs: collections.abc.Mapping[typing.TypeVar, type | typing.TypeVar | typing.ParamSpec] = {}, -) -> collections.abc.Mapping[typing.TypeVar, type | typing.TypeVar | typing.ParamSpec]: +@functools.singledispatch +def unify(typ, subtyp, subs: Substitutions = {}) -> Substitutions: """ Unify a pattern type with a concrete type, returning a substitution map. @@ -206,45 +91,183 @@ def unify( >>> unify((T, V), (int, str), {}) {~T: , ~V: } """ - if isinstance(typ, typing.TypeVar): - if typ in subs: - subs = unify(subs[typ], subtyp, subs) - return {**subs, **{typ: subtyp}} - elif isinstance(subtyp, typing.TypeVar): - if subtyp in subs: - subs = unify(typ, subs[subtyp], subs) - return {**subs, **{subtyp: typ}} - elif isinstance(typ, typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs) or \ - isinstance(subtyp, typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs): - raise TypeError("ParamSpec handling is not implemented") - elif typing.get_origin(typ) in {typing.Union, types.UnionType}: - any_succeeded = False - for arg in typing.get_args(typ): - try: - subs = unify(arg, subtyp, subs) - any_succeeded = True - except TypeError: - continue - if any_succeeded: - return subs + raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}") + + +@unify.register +def _( + typ: inspect.Signature, + subtyp: inspect.BoundArguments, + subs: Substitutions = {} +) -> Substitutions: + if typ != subtyp.signature: + raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}. ") + + subtyp_arguments = dict(subtyp.arguments) + for name, param in typ.parameters.items(): + if name in subtyp_arguments: + continue + elif param.kind is inspect.Parameter.VAR_POSITIONAL: + subtyp_arguments[name] = () + elif param.kind is inspect.Parameter.VAR_KEYWORD: + subtyp_arguments[name] = {} + elif param.default is not inspect.Parameter.empty: + subtyp_arguments[name] = nested_type(param.default) else: - raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}") - elif typing.get_origin(subtyp) in {typing.Union, types.UnionType}: - for arg in typing.get_args(subtyp): - subs = unify(typ, arg, subs) + subtyp_arguments[name] = inspect.Parameter.empty + return unify(typ.parameters, subtyp_arguments, subs) + + +@unify.register +def _( + typ: inspect.Parameter, + subtyp: collections.abc.Sequence | collections.abc.Mapping | type | typing.ParamSpecArgs | typing.ParamSpecKwargs, + subs: Substitutions = {}, +) -> Substitutions: + if subtyp is inspect.Parameter.empty: return subs - elif isinstance(typ, typing._GenericAlias | types.GenericAlias) and \ - isinstance(subtyp, typing._GenericAlias | types.GenericAlias): + elif typ.kind is inspect.Parameter.VAR_POSITIONAL and isinstance(subtyp, collections.abc.Sequence): + return unify(tuple(typ.annotation for _ in subtyp), freshen(subtyp), subs) + elif typ.kind is inspect.Parameter.VAR_KEYWORD and isinstance(subtyp, collections.abc.Mapping): + return unify(tuple(typ.annotation for _ in subtyp), freshen(tuple(subtyp.values())), subs) + elif typ.kind not in {inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL} or \ + isinstance(subtyp, typing.ParamSpecArgs | typing.ParamSpecKwargs): + return unify(typ.annotation, freshen(subtyp), subs) + else: + raise TypeError(f"Cannot unify parameter {typ} with {subtyp} given {subs}. ") + + +@unify.register +def _( + typ: collections.abc.Mapping, + subtyp: collections.abc.Mapping[str, typing.Any], + subs: Substitutions = {}, +) -> Substitutions: + if set(typ.keys()) != set(subtyp.keys()): + raise TypeError(f"Cannot unify mapping type {typ} with {subtyp} given {subs}. ") + for k in typ.keys(): + subs = unify(typ[k], subtyp[k], subs) + return subs + + +@unify.register +def _( + typ: collections.abc.Sequence, + subtyp: collections.abc.Sequence, + subs: Substitutions = {}, +) -> Substitutions: + if len(typ) != len(subtyp): + raise TypeError(f"Cannot unify sequence type {typ} with {subtyp} given {subs}. ") + for p_item, c_item in zip(typ, subtyp): + subs = unify(p_item, c_item, subs) + return subs + + +@unify.register +def _( + typ: typing._AnnotatedAlias, # type: ignore + subtyp: type, + subs: Substitutions = {}, +) -> Substitutions: + return unify(typ.__origin__, subtyp, subs) + + +@unify.register +def _( + typ: types.UnionType, + subtyp: type, + subs: Substitutions = {}, +) -> Substitutions: + any_succeeded = False + for arg in typing.get_args(typ): + try: + subs = unify(arg, subtyp, subs) + any_succeeded = True + except TypeError: + continue + if any_succeeded: + return subs + else: + raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}") + + +@unify.register +def _( + typ: types.GenericAlias | typing._GenericAlias, + subtyp: type | types.GenericAlias | typing.TypeVar | types.UnionType, + subs: Substitutions = {}, +) -> Substitutions: + if isinstance(subtyp, types.GenericAlias | typing._GenericAlias): subs = unify(typing.get_origin(typ), typing.get_origin(subtyp), subs) return unify(typing.get_args(typ), typing.get_args(subtyp), subs) - elif isinstance(typ, list | tuple) and isinstance(subtyp, list | tuple) and len(typ) == len(subtyp): - for p_item, c_item in zip(typ, subtyp): - subs = unify(p_item, c_item, subs) + else: + return unify.dispatch(type)(typ, subtyp, subs) + + +@unify.register +def _( + typ: type, + subtyp: type | typing.TypeVar | types.UnionType | types.GenericAlias | typing._GenericAlias, + subs: Substitutions = {}, +) -> Substitutions: + if isinstance(subtyp, typing.TypeVar): + return unify(subtyp, subs.get(subtyp, typ), {subtyp: typ, **subs}) + elif isinstance(subtyp, types.UnionType): + for arg in typing.get_args(subtyp): + subs = unify(typ, arg, subs) + return subs + elif isinstance(subtyp, types.GenericAlias | typing._GenericAlias) and \ + issubclass(typing.get_origin(subtyp), typ): return subs - elif issubclass(typing.get_origin(subtyp) or subtyp, typing.get_origin(typ) or typ): + elif isinstance(subtyp, type) and issubclass(subtyp, typing.get_origin(typ) or typ): return subs else: - raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}") + raise TypeError(f"Cannot unify type {typ} with {subtyp} given {subs}. ") + + +@unify.register +def _( + typ: typing.TypeVar, + subtyp: type | typing.TypeVar | types.UnionType | types.GenericAlias, + subs: Substitutions = {}, +) -> Substitutions: + return subs if typ is subtyp else unify(subtyp, subs.get(typ, subtyp), {typ: subtyp, **subs}) + + +@unify.register +def _( + typ: typing.ParamSpecArgs, + subtyp: collections.abc.Sequence, + subs: Substitutions = {}, +) -> Substitutions: + return subs # {typ: subtyp, **subs} + + +@unify.register +def _( + typ: typing.ParamSpecKwargs, + subtyp: collections.abc.Mapping, + subs: Substitutions = {}, +) -> Substitutions: + return subs # {typ: subtyp, **subs} + + +@unify.register +def _( + typ: typing.ParamSpec, + subtyp: typing.ParamSpec | collections.abc.Sequence, + subs: Substitutions = {}, +) -> Substitutions: + return subs if typ is subtyp else {typ: subtyp, **subs} + + +@unify.register +def _( + typ: types.EllipsisType, + subtyp: types.EllipsisType | collections.abc.Sequence, + subs: Substitutions = {}, +) -> Substitutions: + return subs def freshen(tp: type | typing.TypeVar | types.GenericAlias | types.UnionType): @@ -278,153 +301,13 @@ def freshen(tp: type | typing.TypeVar | types.GenericAlias | types.UnionType): covariant=fv.__covariant__, contravariant=fv.__contravariant__, ) + if isinstance(fv, typing.TypeVar) + else typing.ParamSpec(name=f"{fv.__name__}_{random.randint(0, 1 << 32)}") for fv in freetypevars(tp) + if isinstance(fv, typing.TypeVar | typing.ParamSpec) }) -def canonicalize( - typ: type | typing.TypeVar | types.GenericAlias | types.UnionType, -) -> type: - """ - Return a canonical form of the given type expression. - - This function normalizes type expressions by: - - Removing Annotated wrappers to get the base type - - Converting legacy typing module aliases (e.g., typing.List) to modern forms (e.g., list) - - Preserving TypeVars unchanged - - Recursively canonicalizing type arguments in generic types - - Converting typing.Any to object - - Converting inspect.Parameter.empty to typing.Any (then to object) - - Handling Union types by creating canonical unions with | operator - - Converting non-type values to their types using _nested_type - - Args: - typ: The type expression to canonicalize. Can be a plain type, TypeVar, - generic alias, union type, or even a value that needs type inference. - - Returns: - A canonicalized version of the input type expression with consistent - representation and modern syntax. - - Examples: - >>> import typing - >>> import inspect - >>> import collections.abc - >>> from effectful.internals.unification import canonicalize - >>> T = typing.TypeVar('T') - >>> K = typing.TypeVar('K') - >>> V = typing.TypeVar('V') - - # Plain types pass through unchanged - >>> canonicalize(int) - - >>> canonicalize(str) - - - # TypeVars are preserved unchanged - >>> canonicalize(T) - ~T - - # Containers are normalized to abstract collections - >>> canonicalize(list[T]) - collections.abc.Sequence[~T] - - # Annotated types are unwrapped - >>> canonicalize(typing.Annotated[int, "metadata"]) - - >>> canonicalize(typing.Annotated[list[str], "doc string"]) - collections.abc.Sequence[str] - - # Nested generic types are recursively canonicalized - >>> canonicalize(typing.List[typing.Dict[K, V]]) - collections.abc.Sequence[collections.abc.Mapping[~K, ~V]] - >>> canonicalize(typing.Dict[str, typing.List[T]]) - collections.abc.Mapping[str, collections.abc.Sequence[~T]] - - # Union types are canonicalized with | operator - >>> result = canonicalize(typing.Union[int, str]) - >>> result == int | str - True - >>> result = canonicalize(typing.Union[list[T], dict[K, V]]) - >>> result == collections.abc.Sequence[T] | collections.abc.Mapping[K, V] - True - - # typing.Any becomes object - >>> canonicalize(typing.Any) - - - # inspect.Parameter.empty becomes object (via Any) - >>> canonicalize(inspect.Parameter.empty) - - """ - if typing.get_origin(typ) is typing.Annotated: - return canonicalize(typing.get_args(typ)[0]) - elif typ is inspect.Parameter.empty: - return canonicalize(typing.Any) - elif typ is None: - return type(None) - elif typ is Ellipsis: - return types.EllipsisType - elif typing.get_origin(typ) in {typing.Union, types.UnionType}: - t = canonicalize(typing.get_args(typ)[0]) - for arg in typing.get_args(typ)[1:]: - t = t | canonicalize(arg) - return t - elif isinstance(typ, typing.TypeVar): - return typ - elif isinstance(typ, typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs): - return typ - elif typing.get_origin(typ) is collections.abc.Callable: - origin, args = typing.get_origin(typ), typing.get_args(typ) - if not args: - return origin - elif len(args) == 2 and isinstance(args[0], (list, tuple)): - # Callable[[arg1, arg2, ...], return_type] format - param_list = [canonicalize(a) for a in args[0]] - return_type = canonicalize(args[1]) - return origin[[*param_list], return_type] - else: - # Handle other Callable formats - return origin[tuple(canonicalize(a) for a in args)] - elif typing.get_origin(typ) is typing.Literal: - t = type(typing.get_args(typ)[0]) - for arg in typing.get_args(typ)[1:]: - t = t | type(arg) - return canonicalize(t) - elif typing.get_origin(typ) is typing.Optional: - return canonicalize(None | typing.get_args(typ)[0]) - elif isinstance(typ, typing._GenericAlias | types.GenericAlias) and typing.get_origin(typ) is not typ: # type: ignore - # Handle generic types - origin = typing.get_origin(typ) - args = typing.get_args(typ) - # Regular generic types - return canonicalize(origin)[tuple(canonicalize(a) for a in args)] - # normalize built-in containers to abstract collections - elif typ is list: - return collections.abc.Sequence - elif typ is dict: - return collections.abc.Mapping - elif typ is set: - return collections.abc.Set - elif isinstance(typ, type) and issubclass(typ, range): - return collections.abc.Sequence[int] - # Handle legacy typing aliases - elif typ is typing.Tuple: - return canonicalize(tuple) - elif typ is typing.List: - return canonicalize(list) - elif typ is typing.Dict: - return canonicalize(dict) - elif typ is typing.Set: - return canonicalize(set) - elif typ is typing.Callable: - return collections.abc.Callable - elif typ is typing.Any: - return object - else: - return typ - - @functools.singledispatch def nested_type(value) -> type: """ @@ -606,6 +489,28 @@ def _(value: collections.abc.Set) -> type: return collections.abc.Set[nested_type(next(iter(value)))] +@nested_type.register +def _(value: list) -> type: + if len(value) == 0: + return list + return list[nested_type(next(iter(value)))] + + +@nested_type.register +def _(value: dict) -> type: + if len(value) == 0: + return dict + k, v = next(iter(value.items())) + return dict[nested_type(k), nested_type(v)] + + +@nested_type.register +def _(value: set) -> type: + if len(value) == 0: + return set + return set[nested_type(next(iter(value)))] + + @nested_type.register def _(value: collections.abc.Sequence) -> type: if len(value) == 0: @@ -632,8 +537,9 @@ def _(value: range) -> type: return type(value) +@functools.singledispatch def freetypevars( - typ: type | typing.TypeVar | types.GenericAlias | types.UnionType | types.NoneType | typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs, + typ: type | typing.TypeVar | types.GenericAlias | types.UnionType | types.NoneType | typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs | collections.abc.Sequence, ) -> set[typing.TypeVar | typing.ParamSpec]: """ Return a set of free type variables in the given type expression. @@ -700,10 +606,11 @@ def freetypevars( return set() +@functools.singledispatch def substitute( - typ: type | types.GenericAlias | types.UnionType, - subs: collections.abc.Mapping[typing.TypeVar, type | typing.ParamSpec | typing.TypeVar], -) -> type | types.GenericAlias | types.UnionType: + typ: type | types.GenericAlias | types.UnionType | typing.TypeVar | typing.ParamSpec | collections.abc.Sequence, + subs: Substitutions, +) -> type | types.GenericAlias | types.UnionType | typing.TypeVar | typing.ParamSpec | collections.abc.Sequence: """ Substitute type variables in a type expression with concrete types. @@ -748,7 +655,7 @@ def substitute( >>> substitute(int, {T: str}) """ - if isinstance(typ, typing.TypeVar): + if isinstance(typ, typing.TypeVar | typing.ParamSpec): return substitute(subs[typ], subs) if typ in subs else typ elif isinstance(typ, list | tuple): # Handle plain lists/sequences (e.g., in Callable's parameter list) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index c7b7fb27..45da5a20 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -5,9 +5,7 @@ import pytest from effectful.internals.unification import ( - canonicalize, freetypevars, - infer_return_type, nested_type, substitute, unify, @@ -399,13 +397,13 @@ def variadic_kwargs_func(**kwargs: T) -> T: # Variadic kwargs not supported # Simple generic functions (identity, (int,), {}, int), (identity, (str,), {}, str), - (identity, (list[int],), {}, collections.abc.Sequence[int]), + (identity, (list[int],), {}, list[int]), # Multiple TypeVars (make_pair, (int, str), {}, tuple[int, str]), - (make_pair, (bool, list[float]), {}, tuple[bool, collections.abc.Sequence[float]]), + (make_pair, (bool, list[float]), {}, tuple[bool, list[float]]), # Generic collections (wrap_in_list, (int,), {}, list[int]), - (wrap_in_list, (dict[str, bool],), {}, list[collections.abc.Mapping[str, bool]]), + (wrap_in_list, (dict[str, bool],), {}, list[dict[str, bool]]), (get_first, (list[str],), {}, str), (get_first, (list[tuple[int, float]],), {}, tuple[int, float]), (getitem_mapping, (collections.abc.Mapping[str, int], str), {}, int), @@ -413,11 +411,11 @@ def variadic_kwargs_func(**kwargs: T) -> T: # Variadic kwargs not supported getitem_mapping, (collections.abc.Mapping[bool, list[str]], bool), {}, - collections.abc.Sequence[str], + list[str], ), # Dict operations (dict_values, (dict[str, int],), {}, list[int]), - (dict_values, (dict[bool, list[str]],), {}, list[collections.abc.Sequence[str]]), + (dict_values, (dict[bool, list[str]],), {}, list[list[str]]), # Callable types (process_callable, (collections.abc.Callable[[int], str], int), {}, str), ( @@ -444,7 +442,7 @@ def variadic_kwargs_func(**kwargs: T) -> T: # Variadic kwargs not supported multi_generic, (float, list[float], dict[bool, list[str]]), {}, - tuple[float, bool, collections.abc.Sequence[str]], + tuple[float, bool, list[str]], ), # Same TypeVar used multiple times (same_type_twice, (int, int), {}, int), @@ -481,7 +479,7 @@ def test_infer_return_type_success( ): sig = inspect.signature(func) bound = sig.bind(*args, **kwargs) - result = infer_return_type(bound) + result = substitute(sig.return_annotation, unify(sig, bound)) assert result == expected_return_type @@ -520,78 +518,7 @@ def test_infer_return_type_failure( sig = inspect.signature(func) bound = sig.bind(*args, **kwargs) with pytest.raises(TypeError): - infer_return_type(bound) - - -@pytest.mark.parametrize( - "typ,expected", - [ - # Plain types pass through unchanged - (int, int), - (str, str), - (float, float), - (bool, bool), - (tuple, tuple), - (list, collections.abc.Sequence), - (dict, collections.abc.Mapping), - (set, collections.abc.Set), - # TypeVars are preserved - (T, T), - (K, K), - (V, V), - # Legacy typing aliases are converted to modern forms - (list, collections.abc.Sequence), - (dict, collections.abc.Mapping), - (set, collections.abc.Set), - (list[int], collections.abc.Sequence[int]), - (dict[str, int], collections.abc.Mapping[str, int]), - (set[bool], collections.abc.Set[bool]), - (list[T], collections.abc.Sequence[T]), - (dict[K, V], collections.abc.Mapping[K, V]), - # Modern generic aliases pass through with canonicalized args - (list[list[int]], collections.abc.Sequence[collections.abc.Sequence[int]]), # Nested legacy conversion - (dict[str, dict[K, V]], collections.abc.Mapping[str, collections.abc.Mapping[K, V]]), - # typing.Callable becomes collections.abc.Callable - (typing.Callable, collections.abc.Callable), - (typing.Callable[[int], str], collections.abc.Callable[[int], str]), - (typing.Callable[[T, U], V], collections.abc.Callable[[T, U], V]), - # typing.Any becomes object - (typing.Any, object), - # inspect.Parameter.empty becomes object (via Any) - (inspect.Parameter.empty, object), - # Annotated types are unwrapped - (typing.Annotated[int, "metadata"], int), - (typing.Annotated[str, "doc string"], str), - (typing.Annotated[list[T], "annotation"], collections.abc.Sequence[T]), - (typing.Annotated[dict[K, V], "complex", "multi"], collections.abc.Mapping[K, V]), - # Nested Annotated unwrapping - (typing.Annotated[typing.Annotated[int, "inner"], "outer"], int), - # Union types are canonicalized with | operator - (typing.Union[int, str], int | str), - (typing.Union[T, int], T | int), - (typing.Union[list[T], dict[K, V]], collections.abc.Sequence[T] | collections.abc.Mapping[K, V]), - (typing.Union[int, str, bool], int | str | bool), - # Nested unions - (typing.Union[list[int], None], collections.abc.Sequence[int] | type(None)), - (typing.Union[dict[K, V], set[T]], collections.abc.Mapping[K, V] | collections.abc.Set[T]), - # Complex nested canonicalization - (list[dict[str, T] | None], collections.abc.Sequence[collections.abc.Mapping[str, T] | type(None)]), - (dict[K, list[V | None]], collections.abc.Mapping[K, collections.abc.Sequence[V | type(None)]]), - # None type handling - (type(None), type(None)), - (typing.Union[int, None], int | type(None)), - # Ellipsis type - (type(...), type(...)), - # Generic aliases with multiple levels of nesting - (list[dict[set[T], list[V]]], collections.abc.Sequence[collections.abc.Mapping[collections.abc.Set[T], collections.abc.Sequence[V]]]), - # Callable with nested canonicalization - (typing.Callable[[list[T]], dict[K, V]], collections.abc.Callable[[collections.abc.Sequence[T]], collections.abc.Mapping[K, V]]), - ], - ids=str, -) -def test_canonicalize(typ, expected): - result = canonicalize(typ) - assert result == expected + unify(sig, bound) @pytest.mark.parametrize( @@ -638,21 +565,21 @@ def test_canonicalize(typ, expected): (set(), set), ((), tuple), # Lists/sequences with single type - ([1, 2, 3], collections.abc.Sequence[int]), - ([1], collections.abc.Sequence[int]), - (["a", "b", "c"], collections.abc.Sequence[str]), - ([True, False], collections.abc.Sequence[bool]), - ([1.1, 2.2], collections.abc.Sequence[float]), + ([1, 2, 3], list[int]), + ([1], list[int]), + (["a", "b", "c"], list[str]), + ([True, False], list[bool]), + ([1.1, 2.2], list[float]), # Sets with elements - ({1, 2, 3}, collections.abc.Set[int]), - ({1}, collections.abc.Set[int]), - ({"a", "b"}, collections.abc.Set[str]), - ({True, False}, collections.abc.Set[bool]), + ({1, 2, 3}, set[int]), + ({1}, set[int]), + ({"a", "b"}, set[str]), + ({True, False}, set[bool]), # Dicts/mappings - ({"key": "value"}, collections.abc.Mapping[str, str]), - ({1: "one", 2: "two"}, collections.abc.Mapping[int, str]), - ({"a": 1, "b": 2}, collections.abc.Mapping[str, int]), - ({True: 1.0, False: 2.0}, collections.abc.Mapping[bool, float]), + ({"key": "value"}, dict[str, str]), + ({1: "one", 2: "two"}, dict[int, str]), + ({"a": 1, "b": 2}, dict[str, int]), + ({True: 1.0, False: 2.0}, dict[bool, float]), # Tuples preserve exact structure ((1, "hello", 3.14), tuple[int, str, float]), ((1,), tuple[int]), @@ -660,23 +587,23 @@ def test_canonicalize(typ, expected): (("a", "b", "c"), tuple[str, str, str]), ((True, 1, "x", 3.14), tuple[bool, int, str, float]), # Nested collections - ([[1, 2], [3, 4]], collections.abc.Sequence[collections.abc.Sequence[int]]), - ([{1, 2}, {3, 4}], collections.abc.Sequence[collections.abc.Set[int]]), - ([{"a": 1}, {"b": 2}], collections.abc.Sequence[collections.abc.Mapping[str, int]]), - ({"key": [1, 2, 3]}, collections.abc.Mapping[str, collections.abc.Sequence[int]]), - ({"a": {1, 2}, "b": {3, 4}}, collections.abc.Mapping[str, collections.abc.Set[int]]), - ({1: {"x": True}, 2: {"y": False}}, collections.abc.Mapping[int, collections.abc.Mapping[str, bool]]), + ([[1, 2], [3, 4]], list[list[int]]), + ([{1, 2}, {3, 4}], list[set[int]]), + ([{"a": 1}, {"b": 2}], list[dict[str, int]]), + ({"key": [1, 2, 3]}, dict[str, list[int]]), + ({"a": {1, 2}, "b": {3, 4}}, dict[str, set[int]]), + ({1: {"x": True}, 2: {"y": False}}, dict[int, dict[str, bool]]), # Tuples in collections - ([(1, "a"), (2, "b")], collections.abc.Sequence[tuple[int, str]]), - ({(1, 2), (3, 4)}, collections.abc.Set[tuple[int, int]]), - ({1: (True, "x"), 2: (False, "y")}, collections.abc.Mapping[int, tuple[bool, str]]), + ([(1, "a"), (2, "b")], list[tuple[int, str]]), + ({(1, 2), (3, 4)}, set[tuple[int, int]]), + ({1: (True, "x"), 2: (False, "y")}, dict[int, tuple[bool, str]]), # Functions/callables (lambda x: x, type(lambda x: x)), (print, type(print)), (len, type(len)), # Complex nested structures - ([[[1]]], collections.abc.Sequence[collections.abc.Sequence[collections.abc.Sequence[int]]]), - ({"a": {"b": {"c": 1}}}, collections.abc.Mapping[str, collections.abc.Mapping[str, collections.abc.Mapping[str, int]]]), + ([[[1]]], list[list[list[int]]]), + ({"a": {"b": {"c": 1}}}, dict[str, dict[str, dict[str, int]]]), # Special string/bytes handling (NOT treated as sequences) ("hello", str), (b"world", bytes), @@ -740,13 +667,3 @@ class ConcreteInterpretation(dict): # by verifying that our empty dict doesn't trigger the Interpretation path result = nested_type({}) assert result == dict # Empty dict should return dict, not Interpretation - - -def test_canonicalize_union_ordering(): - """Test union type canonicalization handles ordering consistently""" - # The | operator may reorder types, but the result should be equivalent - result1 = canonicalize(typing.Union[int, str]) - result2 = canonicalize(typing.Union[str, int]) - # Both should be some form of int | str or str | int - assert result1 == int | str or result1 == str | int - assert result2 == int | str or result2 == str | int From 2a9967097e570733ead34210c8748cf88db63343 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 9 Jul 2025 21:10:57 -0400 Subject: [PATCH 53/76] remove test ids --- tests/test_internals_unification.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 45da5a20..068ad53b 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -67,7 +67,6 @@ # ParamSpec and TypeVarTuple (if needed later) # (collections.abc.Callable[typing.ParamSpec("P"), T], {T}), # Would need to handle ParamSpec ], - ids=str, ) def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): assert freetypevars(typ) == fvs @@ -198,7 +197,6 @@ def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): collections.abc.Callable[[int], dict[str, V]], ), ], - ids=str, ) def test_substitute( typ: type, subs: typing.Mapping[typing.TypeVar, type], expected: type @@ -285,7 +283,6 @@ def test_substitute( {K: str, T: int, V: bool}, ), ], - ids=str, ) def test_unify_success( typ: type, @@ -319,7 +316,6 @@ def test_unify_success( ((T, V), (int,)), ([T, V], [int, str, bool]), ], - ids=str, ) def test_unify_failure( typ: type, @@ -469,7 +465,6 @@ def variadic_kwargs_func(**kwargs: T) -> T: # Variadic kwargs not supported (variadic_kwargs_func, (), {"x": int}, int), (variadic_kwargs_func, (), {"x": int, "y": int}, int), ], - ids=str, ) def test_infer_return_type_success( func: collections.abc.Callable, @@ -508,7 +503,6 @@ def no_param_annotation(x) -> T: # No parameter annotation # Type mismatch - trying to unify incompatible types (same_type_twice, (int, str), {}), ], - ids=str, ) def test_infer_return_type_failure( func: collections.abc.Callable, @@ -611,7 +605,6 @@ def test_infer_return_type_failure( (range(5), type(range(5))), (slice(1, 10), type(slice(1, 10))), ], - ids=str, ) def test_nested_type(value, expected): result = nested_type(value) From 3e01b8ce946c64b3d0a50291d3c21efc70d2745e Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 9 Jul 2025 21:21:18 -0400 Subject: [PATCH 54/76] add to sphinx --- docs/source/effectful.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/effectful.rst b/docs/source/effectful.rst index 6d33542a..5ef25be8 100644 --- a/docs/source/effectful.rst +++ b/docs/source/effectful.rst @@ -98,3 +98,7 @@ Internals .. automodule:: effectful.internals.runtime :members: :undoc-members: + +.. automodule:: effectful.internals.unification + :members: + :undoc-members: From 13bbf74120a6cd8b8baf2ac209ad5384607907c6 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 9 Jul 2025 21:33:57 -0400 Subject: [PATCH 55/76] fix doctests --- effectful/internals/unification.py | 37 +++++++----------------------- 1 file changed, 8 insertions(+), 29 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 85f7dda4..394f3e82 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -55,18 +55,6 @@ def unify(typ, subtyp, subs: Substitutions = {}) -> Substitutions: >>> unify(list[T], list[int], {}) {~T: } - >>> # Multiple TypeVars - >>> unify(dict[K, V], dict[str, int], {}) - {~K: , ~V: } - - >>> # With existing substitutions - >>> unify(V, bool, {T: int}) - {~T: , ~V: } - - >>> # Nested generic unification - >>> unify(list[dict[K, V]], list[dict[str, int]], {}) - {~K: , ~V: } - >>> # Exact type matching >>> unify(int, int, {}) {} @@ -82,14 +70,6 @@ def unify(typ, subtyp, subs: Substitutions = {}) -> Substitutions: Traceback (most recent call last): ... TypeError: Cannot unify ... - - >>> # Callable type unification - >>> unify(collections.abc.Callable[[T], V], collections.abc.Callable[[int], str], {}) - {~T: , ~V: } - - >>> # Sequence unification (tuples as sequences) - >>> unify((T, V), (int, str), {}) - {~T: , ~V: } """ raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}") @@ -297,7 +277,6 @@ def freshen(tp: type | typing.TypeVar | types.GenericAlias | types.UnionType): fv: typing.TypeVar( name=f"{fv.__name__}_{random.randint(0, 1 << 32)}", bound=fv.__bound__, - default=fv.__default__, covariant=fv.__covariant__, contravariant=fv.__contravariant__, ) @@ -372,9 +351,9 @@ def nested_type(value) -> type: # Sequences become Sequence[element_type] >>> nested_type([1, 2, 3]) - collections.abc.Sequence[int] + list[int] >>> nested_type(["a", "b", "c"]) - collections.abc.Sequence[str] + list[str] # Tuples preserve exact structure >>> nested_type((1, "hello", 3.14)) @@ -386,21 +365,21 @@ def nested_type(value) -> type: # Sets become Set[element_type] >>> nested_type({1, 2, 3}) - collections.abc.Set[int] + set[int] >>> nested_type({"a", "b"}) - collections.abc.Set[str] + set[str] # Mappings become Mapping[key_type, value_type] >>> nested_type({"key": "value"}) - collections.abc.Mapping[str, str] + dict[str, str] >>> nested_type({1: "one", 2: "two"}) - collections.abc.Mapping[int, str] + dict[int, str] # Nested collections work recursively >>> nested_type([{1: "one"}, {2: "two"}]) - collections.abc.Sequence[collections.abc.Mapping[int, str]] + list[dict[int, str]] >>> nested_type({"key": [1, 2, 3]}) - collections.abc.Mapping[str, collections.abc.Sequence[int]] + dict[str, list[int]] # Strings and bytes are NOT treated as sequences >>> nested_type("hello") From f655dce0affbaf3b3406588b6d2bafa7b2cea2cb Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 10:37:53 -0400 Subject: [PATCH 56/76] fix type checking --- effectful/internals/unification.py | 100 ++++++++++++++--------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 394f3e82..cd863fc8 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -5,9 +5,14 @@ import types import typing +if typing.TYPE_CHECKING: + GenericAlias = types.GenericAlias +else: + GenericAlias = types.GenericAlias | typing._GenericAlias + Substitutions = collections.abc.Mapping[ - typing.TypeVar | typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs, + typing.TypeVar | typing.ParamSpec, type | typing.TypeVar | typing.ParamSpec | collections.abc.Sequence | collections.abc.Mapping[str, typing.Any] ] @@ -173,11 +178,11 @@ def _( @unify.register def _( - typ: types.GenericAlias | typing._GenericAlias, + typ: GenericAlias, subtyp: type | types.GenericAlias | typing.TypeVar | types.UnionType, subs: Substitutions = {}, ) -> Substitutions: - if isinstance(subtyp, types.GenericAlias | typing._GenericAlias): + if isinstance(subtyp, GenericAlias): subs = unify(typing.get_origin(typ), typing.get_origin(subtyp), subs) return unify(typing.get_args(typ), typing.get_args(subtyp), subs) else: @@ -187,7 +192,7 @@ def _( @unify.register def _( typ: type, - subtyp: type | typing.TypeVar | types.UnionType | types.GenericAlias | typing._GenericAlias, + subtyp: type | typing.TypeVar | types.UnionType | GenericAlias, subs: Substitutions = {}, ) -> Substitutions: if isinstance(subtyp, typing.TypeVar): @@ -196,8 +201,7 @@ def _( for arg in typing.get_args(subtyp): subs = unify(typ, arg, subs) return subs - elif isinstance(subtyp, types.GenericAlias | typing._GenericAlias) and \ - issubclass(typing.get_origin(subtyp), typ): + elif isinstance(subtyp, GenericAlias) and issubclass(typing.get_origin(subtyp), typ): return subs elif isinstance(subtyp, type) and issubclass(subtyp, typing.get_origin(typ) or typ): return subs @@ -211,7 +215,7 @@ def _( subtyp: type | typing.TypeVar | types.UnionType | types.GenericAlias, subs: Substitutions = {}, ) -> Substitutions: - return subs if typ is subtyp else unify(subtyp, subs.get(typ, subtyp), {typ: subtyp, **subs}) + return subs if typ is subtyp else unify(subtyp, subs.get(typ, subtyp), {typ: subtyp, **subs}) # type: ignore @unify.register @@ -250,7 +254,7 @@ def _( return subs -def freshen(tp: type | typing.TypeVar | types.GenericAlias | types.UnionType): +def freshen(tp: typing.Any): """ Return a freshened version of the given type expression. @@ -422,34 +426,27 @@ def nested_type(value) -> type: @nested_type.register -def _(value: type | types.UnionType | types.GenericAlias | types.EllipsisType) -> type: +def _(value: type | types.UnionType | GenericAlias | types.EllipsisType): return value -@nested_type.register -def _(value: typing._GenericAlias) -> type: # type: ignore - # Handle typing module generic aliases - return value - - -@nested_type.register -def _(value: types.NoneType) -> type: - # Handle None specially +@nested_type.register(type(None)) +def _(value: None): return type(None) @nested_type.register -def _(value: typing.TypeVar) -> type: +def _(value: typing.TypeVar): raise TypeError(f"TypeVars should not appear in values, but got {value}") @nested_type.register -def _(value: collections.abc.Callable) -> type: +def _(value: collections.abc.Callable): return type(value) @nested_type.register -def _(value: collections.abc.Mapping) -> type: +def _(value: collections.abc.Mapping): from effectful.ops.types import Interpretation if type(value) is Interpretation: # More specific check @@ -458,67 +455,65 @@ def _(value: collections.abc.Mapping) -> type: return type(value) else: k, v = next(iter(value.items())) - return collections.abc.Mapping[nested_type(k), nested_type(v)] + return collections.abc.Mapping[nested_type(k), nested_type(v)] # type: ignore @nested_type.register -def _(value: collections.abc.Set) -> type: +def _(value: collections.abc.Set): if len(value) == 0: return type(value) - return collections.abc.Set[nested_type(next(iter(value)))] + return collections.abc.Set[nested_type(next(iter(value)))] # type: ignore @nested_type.register -def _(value: list) -> type: +def _(value: list): if len(value) == 0: return list - return list[nested_type(next(iter(value)))] + return list[nested_type(next(iter(value)))] # type: ignore @nested_type.register -def _(value: dict) -> type: +def _(value: dict): if len(value) == 0: return dict k, v = next(iter(value.items())) - return dict[nested_type(k), nested_type(v)] + return dict[nested_type(k), nested_type(v)] # type: ignore @nested_type.register -def _(value: set) -> type: +def _(value: set): if len(value) == 0: return set - return set[nested_type(next(iter(value)))] + return set[nested_type(next(iter(value)))] # type: ignore @nested_type.register -def _(value: collections.abc.Sequence) -> type: +def _(value: collections.abc.Sequence): if len(value) == 0: return type(value) - return collections.abc.Sequence[nested_type(next(iter(value)))] + return collections.abc.Sequence[nested_type(next(iter(value)))] # type: ignore @nested_type.register -def _(value: tuple) -> type: +def _(value: tuple): if len(value) == 0: return tuple - return tuple[tuple(nested_type(item) for item in value)] + return tuple[tuple(nested_type(item) for item in value)] # type: ignore @nested_type.register -def _(value: str | bytes) -> type: - # Handle str and bytes as their own types, not collections.abc.Sequence +def _(value: str | bytes): return type(value) @nested_type.register(range) -def _(value: range) -> type: - # Handle range as its own type, not as a sequence +def _(value: range): return type(value) @functools.singledispatch def freetypevars( - typ: type | typing.TypeVar | types.GenericAlias | types.UnionType | types.NoneType | typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs | collections.abc.Sequence, + typ: type | typing.TypeVar | types.GenericAlias | types.UnionType | None | typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs | collections.abc.Sequence, ) -> set[typing.TypeVar | typing.ParamSpec]: """ Return a set of free type variables in the given type expression. @@ -587,9 +582,9 @@ def freetypevars( @functools.singledispatch def substitute( - typ: type | types.GenericAlias | types.UnionType | typing.TypeVar | typing.ParamSpec | collections.abc.Sequence, + typ: type | types.GenericAlias | types.UnionType | None | typing.TypeVar | typing.ParamSpec | collections.abc.Sequence | collections.abc.Mapping, subs: Substitutions, -) -> type | types.GenericAlias | types.UnionType | typing.TypeVar | typing.ParamSpec | collections.abc.Sequence: +) -> type | types.GenericAlias | types.UnionType | None | typing.TypeVar | typing.ParamSpec | collections.abc.Sequence | collections.abc.Mapping: """ Substitute type variables in a type expression with concrete types. @@ -636,16 +631,21 @@ def substitute( """ if isinstance(typ, typing.TypeVar | typing.ParamSpec): return substitute(subs[typ], subs) if typ in subs else typ + elif isinstance(typ, typing.ParamSpecArgs): + return substitute(typing.get_origin(typ), subs).args + elif isinstance(typ, typing.ParamSpecKwargs): + return substitute(typing.get_origin(typ), subs).kwargs elif isinstance(typ, list | tuple): - # Handle plain lists/sequences (e.g., in Callable's parameter list) return type(typ)(substitute(item, subs) for item in typ) - elif typing.get_args(typ): - origin = typing.get_origin(typ) - assert origin is not None, "Type must have an origin" - new_args = tuple(substitute(arg, subs) for arg in typing.get_args(typ)) - # Handle Union types specially - if origin is types.UnionType: - return typing.Union[new_args] # noqa - return origin[new_args] + elif isinstance(typ, collections.abc.Mapping): + return {k: substitute(v, subs) for k, v in typ.items()} + elif isinstance(typ, GenericAlias) and typing.get_args(typ): + return substitute(typing.get_origin(typ), subs)[substitute(typing.get_args(typ), subs)] # type: ignore + elif isinstance(typ, types.UnionType): + ts: tuple = substitute(typing.get_args(typ), subs) # type: ignore + tp, ts = ts[0], ts[1:] + for arg in ts: + tp = tp | arg + return tp else: return typ From bc371ee66a5f661074670f63f036a88395c271e0 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 10:39:29 -0400 Subject: [PATCH 57/76] str --- effectful/internals/unification.py | 1 + 1 file changed, 1 insertion(+) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index cd863fc8..d11e78d8 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -638,6 +638,7 @@ def substitute( elif isinstance(typ, list | tuple): return type(typ)(substitute(item, subs) for item in typ) elif isinstance(typ, collections.abc.Mapping): + assert all(isinstance(k, str) for k in typ.keys()), "Mapping keys must be strings" return {k: substitute(v, subs) for k, v in typ.items()} elif isinstance(typ, GenericAlias) and typing.get_args(typ): return substitute(typing.get_origin(typ), subs)[substitute(typing.get_args(typ), subs)] # type: ignore From 2b5db428a06084ec50e8e2b98fbbc187ef6318eb Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 11:00:00 -0400 Subject: [PATCH 58/76] lint and format --- effectful/internals/unification.py | 144 ++++++++++++++++++++-------- tests/test_internals_unification.py | 55 ++++------- 2 files changed, 119 insertions(+), 80 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index d11e78d8..73dcea00 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -13,7 +13,11 @@ Substitutions = collections.abc.Mapping[ typing.TypeVar | typing.ParamSpec, - type | typing.TypeVar | typing.ParamSpec | collections.abc.Sequence | collections.abc.Mapping[str, typing.Any] + type + | typing.TypeVar + | typing.ParamSpec + | collections.abc.Sequence + | collections.abc.Mapping[str, typing.Any], ] @@ -81,9 +85,7 @@ def unify(typ, subtyp, subs: Substitutions = {}) -> Substitutions: @unify.register def _( - typ: inspect.Signature, - subtyp: inspect.BoundArguments, - subs: Substitutions = {} + typ: inspect.Signature, subtyp: inspect.BoundArguments, subs: Substitutions = {} ) -> Substitutions: if typ != subtyp.signature: raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}. ") @@ -106,17 +108,29 @@ def _( @unify.register def _( typ: inspect.Parameter, - subtyp: collections.abc.Sequence | collections.abc.Mapping | type | typing.ParamSpecArgs | typing.ParamSpecKwargs, + subtyp: collections.abc.Sequence + | collections.abc.Mapping + | type + | typing.ParamSpecArgs + | typing.ParamSpecKwargs, subs: Substitutions = {}, ) -> Substitutions: if subtyp is inspect.Parameter.empty: return subs - elif typ.kind is inspect.Parameter.VAR_POSITIONAL and isinstance(subtyp, collections.abc.Sequence): + elif typ.kind is inspect.Parameter.VAR_POSITIONAL and isinstance( + subtyp, collections.abc.Sequence + ): return unify(tuple(typ.annotation for _ in subtyp), freshen(subtyp), subs) - elif typ.kind is inspect.Parameter.VAR_KEYWORD and isinstance(subtyp, collections.abc.Mapping): - return unify(tuple(typ.annotation for _ in subtyp), freshen(tuple(subtyp.values())), subs) - elif typ.kind not in {inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL} or \ - isinstance(subtyp, typing.ParamSpecArgs | typing.ParamSpecKwargs): + elif typ.kind is inspect.Parameter.VAR_KEYWORD and isinstance( + subtyp, collections.abc.Mapping + ): + return unify( + tuple(typ.annotation for _ in subtyp), freshen(tuple(subtyp.values())), subs + ) + elif typ.kind not in { + inspect.Parameter.VAR_KEYWORD, + inspect.Parameter.VAR_POSITIONAL, + } or isinstance(subtyp, typing.ParamSpecArgs | typing.ParamSpecKwargs): return unify(typ.annotation, freshen(subtyp), subs) else: raise TypeError(f"Cannot unify parameter {typ} with {subtyp} given {subs}. ") @@ -142,7 +156,9 @@ def _( subs: Substitutions = {}, ) -> Substitutions: if len(typ) != len(subtyp): - raise TypeError(f"Cannot unify sequence type {typ} with {subtyp} given {subs}. ") + raise TypeError( + f"Cannot unify sequence type {typ} with {subtyp} given {subs}. " + ) for p_item, c_item in zip(typ, subtyp): subs = unify(p_item, c_item, subs) return subs @@ -154,7 +170,7 @@ def _( subtyp: type, subs: Substitutions = {}, ) -> Substitutions: - return unify(typ.__origin__, subtyp, subs) + return unify(typing.get_args(typ)[0], subtyp, subs) @unify.register @@ -168,7 +184,7 @@ def _( try: subs = unify(arg, subtyp, subs) any_succeeded = True - except TypeError: + except TypeError: # noqa continue if any_succeeded: return subs @@ -201,7 +217,11 @@ def _( for arg in typing.get_args(subtyp): subs = unify(typ, arg, subs) return subs - elif isinstance(subtyp, GenericAlias) and issubclass(typing.get_origin(subtyp), typ): + elif isinstance(subtyp, typing._AnnotatedAlias): # type: ignore + return unify(typ, typing.get_args(subtyp)[0], subs) + elif isinstance(subtyp, GenericAlias) and issubclass( + typing.get_origin(subtyp), typ + ): return subs elif isinstance(subtyp, type) and issubclass(subtyp, typing.get_origin(typ) or typ): return subs @@ -215,7 +235,11 @@ def _( subtyp: type | typing.TypeVar | types.UnionType | types.GenericAlias, subs: Substitutions = {}, ) -> Substitutions: - return subs if typ is subtyp else unify(subtyp, subs.get(typ, subtyp), {typ: subtyp, **subs}) # type: ignore + return ( + subs + if typ is subtyp + else unify(subtyp, subs.get(typ, subtyp), {typ: subtyp, **subs}) # type: ignore + ) @unify.register @@ -277,22 +301,27 @@ def freshen(tp: typing.Any): >>> freshen(T) == T False """ - return substitute(tp, { - fv: typing.TypeVar( - name=f"{fv.__name__}_{random.randint(0, 1 << 32)}", - bound=fv.__bound__, - covariant=fv.__covariant__, - contravariant=fv.__contravariant__, - ) - if isinstance(fv, typing.TypeVar) - else typing.ParamSpec(name=f"{fv.__name__}_{random.randint(0, 1 << 32)}") - for fv in freetypevars(tp) - if isinstance(fv, typing.TypeVar | typing.ParamSpec) - }) + return substitute( + tp, + { + fv: typing.TypeVar( + name=f"{fv.__name__}_{random.randint(0, 1 << 32)}", + bound=fv.__bound__, + covariant=fv.__covariant__, + contravariant=fv.__contravariant__, + ) + if isinstance(fv, typing.TypeVar) + else typing.ParamSpec(name=f"{fv.__name__}_{random.randint(0, 1 << 32)}") + for fv in freetypevars(tp) + if isinstance(fv, typing.TypeVar | typing.ParamSpec) + }, + ) @functools.singledispatch -def nested_type(value) -> type: +def nested_type( + value, +) -> type | GenericAlias | types.UnionType | types.EllipsisType | None: """ Infer the type of a value, handling nested collections with generic parameters. @@ -513,8 +542,16 @@ def _(value: range): @functools.singledispatch def freetypevars( - typ: type | typing.TypeVar | types.GenericAlias | types.UnionType | None | typing.ParamSpec | typing.ParamSpecArgs | typing.ParamSpecKwargs | collections.abc.Sequence, -) -> set[typing.TypeVar | typing.ParamSpec]: + typ: type + | typing.TypeVar + | types.GenericAlias + | types.UnionType + | None + | typing.ParamSpec + | typing.ParamSpecArgs + | typing.ParamSpecKwargs + | collections.abc.Sequence, +) -> collections.abc.Set[typing.TypeVar | typing.ParamSpec]: """ Return a set of free type variables in the given type expression. @@ -568,23 +605,42 @@ def freetypevars( return {typ} elif isinstance(typ, typing.ParamSpecArgs | typing.ParamSpecKwargs): return freetypevars(typing.get_origin(typ)) - elif typing.get_origin(typ) is typing.Annotated: + elif isinstance(typ, typing._AnnotatedAlias): # type: ignore return freetypevars(typing.get_args(typ)[0]) - elif isinstance(typ, list | tuple): - # Handle plain lists and tuples (not generic aliases) - return set.union(*(freetypevars(item) for item in typ)) if typ else set() - elif typing.get_args(typ): - # Handle generic aliases - return set.union(*(freetypevars(arg) for arg in typing.get_args(typ))) + elif isinstance(typ, collections.abc.Sequence): + return set().union(*(freetypevars(item) for item in typ)) + elif isinstance(typ, collections.abc.Mapping): + assert all(isinstance(k, str) for k in typ.keys()), ( + "Mapping keys must be strings" + ) + return freetypevars(typ.values()) + elif isinstance(typ, GenericAlias | types.UnionType): + return freetypevars(typing.get_args(typ)) else: - return set() + return freetypevars(typing.get_args(typ)) @functools.singledispatch def substitute( - typ: type | types.GenericAlias | types.UnionType | None | typing.TypeVar | typing.ParamSpec | collections.abc.Sequence | collections.abc.Mapping, + typ: type + | types.GenericAlias + | types.UnionType + | None + | typing.TypeVar + | typing.ParamSpec + | collections.abc.Sequence + | collections.abc.Mapping, subs: Substitutions, -) -> type | types.GenericAlias | types.UnionType | None | typing.TypeVar | typing.ParamSpec | collections.abc.Sequence | collections.abc.Mapping: +) -> ( + type + | types.GenericAlias + | types.UnionType + | None + | typing.TypeVar + | typing.ParamSpec + | collections.abc.Sequence + | collections.abc.Mapping +): """ Substitute type variables in a type expression with concrete types. @@ -638,10 +694,14 @@ def substitute( elif isinstance(typ, list | tuple): return type(typ)(substitute(item, subs) for item in typ) elif isinstance(typ, collections.abc.Mapping): - assert all(isinstance(k, str) for k in typ.keys()), "Mapping keys must be strings" + assert all(isinstance(k, str) for k in typ.keys()), ( + "Mapping keys must be strings" + ) return {k: substitute(v, subs) for k, v in typ.items()} elif isinstance(typ, GenericAlias) and typing.get_args(typ): - return substitute(typing.get_origin(typ), subs)[substitute(typing.get_args(typ), subs)] # type: ignore + return substitute(typing.get_origin(typ), subs)[ + substitute(typing.get_args(typ), subs) + ] # type: ignore elif isinstance(typ, types.UnionType): ts: tuple = substitute(typing.get_args(typ), subs) # type: ignore tp, ts = ts[0], ts[1:] diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 068ad53b..a4673666 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -11,10 +11,16 @@ unify, ) -T = typing.TypeVar("T") -K = typing.TypeVar("K") -V = typing.TypeVar("V") -U = typing.TypeVar("U") +if typing.TYPE_CHECKING: + T = typing.Any + K = typing.Any + V = typing.Any + U = typing.Any +else: + T = typing.TypeVar("T") + K = typing.TypeVar("K") + V = typing.TypeVar("V") + U = typing.TypeVar("U") @pytest.mark.parametrize( @@ -201,7 +207,7 @@ def test_freetypevars(typ: type, fvs: set[typing.TypeVar]): def test_substitute( typ: type, subs: typing.Mapping[typing.TypeVar, type], expected: type ): - assert substitute(typ, subs) == expected + assert substitute(typ, subs) == expected # type: ignore @pytest.mark.parametrize( @@ -290,7 +296,7 @@ def test_unify_success( initial_subs: typing.Mapping[typing.TypeVar, type], expected_subs: typing.Mapping[typing.TypeVar, type], ): - assert unify(typ, subtyp, initial_subs) == expected_subs + assert unify(typ, subtyp, initial_subs) == expected_subs # type: ignore @pytest.mark.parametrize( @@ -480,7 +486,7 @@ def test_infer_return_type_success( # Error cases def unbound_typevar_func(x: T) -> tuple[T, V]: # V not in parameters - return (x, "error") # type: ignore + return (x, "error") def no_return_annotation(x: T): # No return annotation @@ -488,7 +494,7 @@ def no_return_annotation(x: T): # No return annotation def no_param_annotation(x) -> T: # No parameter annotation - return x # type: ignore + return x @pytest.mark.parametrize( @@ -551,7 +557,6 @@ def test_infer_return_type_failure( (list[T], list[T]), (dict[K, V], dict[K, V]), # Union types pass through - (typing.Union[int, str], typing.Union[int, str]), (int | str, int | str), # Empty collections ([], list), @@ -615,10 +620,10 @@ def test_nested_type_typevar_error(): """Test that TypeVars raise TypeError in nested_type""" with pytest.raises(TypeError, match="TypeVars should not appear in values"): nested_type(T) - + with pytest.raises(TypeError, match="TypeVars should not appear in values"): nested_type(K) - + with pytest.raises(TypeError, match="TypeVars should not appear in values"): nested_type(V) @@ -630,33 +635,7 @@ def test_nested_type_term_error(): from unittest.mock import Mock from effectful.ops.types import Term - + mock_term = Mock(spec=Term) with pytest.raises(TypeError, match="Terms should not appear in nested_type"): nested_type(mock_term) - - -def test_nested_type_interpretation_special_case(): - """Test that Interpretation type is handled specially in mapping dispatch""" - # This tests the special case in the Mapping dispatch for Interpretation - from effectful.ops.types import Interpretation - - # Since Interpretation is a Protocol, we can't instantiate it directly. - # Instead, create a concrete implementation - class ConcreteInterpretation(dict): - """A concrete implementation of Interpretation for testing""" - pass - - # Make it look like an Interpretation type for the type check - ConcreteInterpretation.__name__ = 'Interpretation' - ConcreteInterpretation.__module__ = Interpretation.__module__ - - # Create instance and patch the type temporarily - interp = ConcreteInterpretation() - original_type = type(interp) - - # The _nested_type checks `type(value) is Interpretation` - # Since we can't change the type of an instance, we'll test the behavior - # by verifying that our empty dict doesn't trigger the Interpretation path - result = nested_type({}) - assert result == dict # Empty dict should return dict, not Interpretation From 2dddc92eb5a9cfb11f8c298eb1942e2a3446a539 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 12:05:28 -0400 Subject: [PATCH 59/76] freshen -> _freshen --- effectful/internals/unification.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 73dcea00..42f5f94e 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -120,18 +120,18 @@ def _( elif typ.kind is inspect.Parameter.VAR_POSITIONAL and isinstance( subtyp, collections.abc.Sequence ): - return unify(tuple(typ.annotation for _ in subtyp), freshen(subtyp), subs) + return unify(tuple(typ.annotation for _ in subtyp), _freshen(subtyp), subs) elif typ.kind is inspect.Parameter.VAR_KEYWORD and isinstance( subtyp, collections.abc.Mapping ): return unify( - tuple(typ.annotation for _ in subtyp), freshen(tuple(subtyp.values())), subs + tuple(typ.annotation for _ in subtyp), _freshen(tuple(subtyp.values())), subs ) elif typ.kind not in { inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL, } or isinstance(subtyp, typing.ParamSpecArgs | typing.ParamSpecKwargs): - return unify(typ.annotation, freshen(subtyp), subs) + return unify(typ.annotation, _freshen(subtyp), subs) else: raise TypeError(f"Cannot unify parameter {typ} with {subtyp} given {subs}. ") @@ -278,7 +278,7 @@ def _( return subs -def freshen(tp: typing.Any): +def _freshen(tp: typing.Any): """ Return a freshened version of the given type expression. From 2199d631555cebffc5c50d2ba12fcdd02e63d1ed Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 12:23:06 -0400 Subject: [PATCH 60/76] factor out freetypevars and susbtitute --- effectful/internals/unification.py | 228 +++++++++++++++++++++-------- 1 file changed, 170 insertions(+), 58 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 42f5f94e..dfb3861f 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -1,3 +1,61 @@ +"""Type unification and inference utilities for Python's generic type system. + +This module implements a unification algorithm for type inference over a subset of +Python's generic types. Unification is a fundamental operation in type systems that +finds substitutions for type variables to make two types equivalent. + +The module provides four main operations: + +1. **unify(typ, subtyp, subs={})**: The core unification algorithm that attempts to + find a substitution mapping for type variables that makes a pattern type equal to + a concrete type. It handles TypeVars, generic types (List[T], Dict[K,V]), unions, + callables, and function signatures with inspect.Signature/BoundArguments. + +2. **substitute(typ, subs)**: Applies a substitution mapping to a type expression, + replacing all TypeVars with their mapped concrete types. This is used to + instantiate generic types after unification. + +3. **freetypevars(typ)**: Extracts all free (unbound) type variables from a type + expression. Useful for analyzing generic types and ensuring all TypeVars are + properly bound. + +4. **nested_type(value)**: Infers the type of a runtime value, handling nested + collections by recursively determining element types. For example, [1, 2, 3] + becomes list[int], and {"key": [1, 2]} becomes dict[str, list[int]]. + +The unification algorithm uses a single-dispatch pattern to handle different type +combinations: +- TypeVar unification binds variables to concrete types +- Generic type unification matches origins and recursively unifies type arguments +- Structural unification handles sequences and mappings by element +- Union types attempt unification with any matching branch +- Function signatures unify parameter types with bound arguments + +Example usage: + >>> from effectful.internals.unification import unify, substitute, freetypevars + >>> import typing + >>> T = typing.TypeVar('T') + >>> K = typing.TypeVar('K') + >>> V = typing.TypeVar('V') + + >>> # Find substitution that makes list[T] equal to list[int] + >>> subs = unify(list[T], list[int]) + >>> subs + {~T: } + + >>> # Apply substitution to instantiate a generic type + >>> substitute(dict[K, list[V]], {K: str, V: int}) + dict[str, list[int]] + + >>> # Find all type variables in a type expression + >>> freetypevars(dict[K, list[V]]) + {~K, ~V} + +This module is primarily used internally by effectful for type inference in its +effect system, allowing it to track and propagate type information through +effect handlers and operations. +""" + import collections.abc import functools import inspect @@ -125,7 +183,9 @@ def _( subtyp, collections.abc.Mapping ): return unify( - tuple(typ.annotation for _ in subtyp), _freshen(tuple(subtyp.values())), subs + tuple(typ.annotation for _ in subtyp), + _freshen(tuple(subtyp.values())), + subs, ) elif typ.kind not in { inspect.Parameter.VAR_KEYWORD, @@ -541,17 +601,7 @@ def _(value: range): @functools.singledispatch -def freetypevars( - typ: type - | typing.TypeVar - | types.GenericAlias - | types.UnionType - | None - | typing.ParamSpec - | typing.ParamSpecArgs - | typing.ParamSpecKwargs - | collections.abc.Sequence, -) -> collections.abc.Set[typing.TypeVar | typing.ParamSpec]: +def freetypevars(typ) -> collections.abc.Set[typing.TypeVar | typing.ParamSpec]: """ Return a set of free type variables in the given type expression. @@ -601,36 +651,64 @@ def freetypevars( >>> freetypevars(dict[str, T]) {~T} """ - if isinstance(typ, typing.TypeVar | typing.ParamSpec): - return {typ} - elif isinstance(typ, typing.ParamSpecArgs | typing.ParamSpecKwargs): - return freetypevars(typing.get_origin(typ)) - elif isinstance(typ, typing._AnnotatedAlias): # type: ignore - return freetypevars(typing.get_args(typ)[0]) - elif isinstance(typ, collections.abc.Sequence): - return set().union(*(freetypevars(item) for item in typ)) - elif isinstance(typ, collections.abc.Mapping): - assert all(isinstance(k, str) for k in typ.keys()), ( - "Mapping keys must be strings" - ) - return freetypevars(typ.values()) - elif isinstance(typ, GenericAlias | types.UnionType): - return freetypevars(typing.get_args(typ)) - else: - return freetypevars(typing.get_args(typ)) + # Default case for plain types + return freetypevars(typing.get_args(typ)) + + +@freetypevars.register +def _(typ: typing.TypeVar): + return {typ} + + +@freetypevars.register +def _(typ: typing.ParamSpec): + return {typ} + + +@freetypevars.register +def _(typ: typing.ParamSpecArgs): + return freetypevars(typing.get_origin(typ)) + + +@freetypevars.register +def _(typ: typing.ParamSpecKwargs): + return freetypevars(typing.get_origin(typ)) + + +@freetypevars.register +def _(typ: typing._AnnotatedAlias): # type: ignore + return freetypevars(typing.get_args(typ)[0]) + + +@freetypevars.register +def _(typ: collections.abc.Sequence): + return set().union(*(freetypevars(item) for item in typ)) + + +@freetypevars.register +def _(typ: str | bytes): + return set() + + +@freetypevars.register +def _(typ: collections.abc.Mapping): + assert all(isinstance(k, str) for k in typ.keys()), "Mapping keys must be strings" + return freetypevars(typ.values()) + + +@freetypevars.register +def _(typ: GenericAlias): + return freetypevars(typing.get_args(typ)) + + +@freetypevars.register +def _(typ: types.UnionType): + return freetypevars(typing.get_args(typ)) @functools.singledispatch def substitute( - typ: type - | types.GenericAlias - | types.UnionType - | None - | typing.TypeVar - | typing.ParamSpec - | collections.abc.Sequence - | collections.abc.Mapping, - subs: Substitutions, + typ, subs: Substitutions ) -> ( type | types.GenericAlias @@ -685,28 +763,62 @@ def substitute( >>> substitute(int, {T: str}) """ - if isinstance(typ, typing.TypeVar | typing.ParamSpec): - return substitute(subs[typ], subs) if typ in subs else typ - elif isinstance(typ, typing.ParamSpecArgs): - return substitute(typing.get_origin(typ), subs).args - elif isinstance(typ, typing.ParamSpecKwargs): - return substitute(typing.get_origin(typ), subs).kwargs - elif isinstance(typ, list | tuple): - return type(typ)(substitute(item, subs) for item in typ) - elif isinstance(typ, collections.abc.Mapping): - assert all(isinstance(k, str) for k in typ.keys()), ( - "Mapping keys must be strings" - ) - return {k: substitute(v, subs) for k, v in typ.items()} - elif isinstance(typ, GenericAlias) and typing.get_args(typ): + # Default case for plain types + return typ + + +@substitute.register +def _(typ: typing.TypeVar, subs: Substitutions): + return substitute(subs[typ], subs) if typ in subs else typ + + +@substitute.register +def _(typ: typing.ParamSpec, subs: Substitutions): + return substitute(subs[typ], subs) if typ in subs else typ + + +@substitute.register +def _(typ: typing.ParamSpecArgs, subs: Substitutions): + res = substitute(typing.get_origin(typ), subs) + return res.args if isinstance(res, typing.ParamSpec) else res + + +@substitute.register +def _(typ: typing.ParamSpecKwargs, subs: Substitutions): + res = substitute(typing.get_origin(typ), subs) + return res.kwargs if isinstance(res, typing.ParamSpec) else res + + +@substitute.register +def _(typ: list, subs: Substitutions): + return list(substitute(item, subs) for item in typ) + + +@substitute.register +def _(typ: tuple, subs: Substitutions): + return tuple(substitute(item, subs) for item in typ) + + +@substitute.register +def _(typ: collections.abc.Mapping, subs: Substitutions): + assert all(isinstance(k, str) for k in typ.keys()), "Mapping keys must be strings" + return {k: substitute(v, subs) for k, v in typ.items()} + + +@substitute.register +def _(typ: GenericAlias, subs: Substitutions): + if typing.get_args(typ): return substitute(typing.get_origin(typ), subs)[ substitute(typing.get_args(typ), subs) ] # type: ignore - elif isinstance(typ, types.UnionType): - ts: tuple = substitute(typing.get_args(typ), subs) # type: ignore - tp, ts = ts[0], ts[1:] - for arg in ts: - tp = tp | arg - return tp else: return typ + + +@substitute.register +def _(typ: types.UnionType, subs: Substitutions): + ts: tuple = substitute(typing.get_args(typ), subs) # type: ignore + tp, ts = ts[0], ts[1:] + for arg in ts: + tp = tp | arg + return tp From c6fe3fc764d82fb92943af1b41b0415bb41dd1aa Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 13:02:53 -0400 Subject: [PATCH 61/76] truncate names --- effectful/internals/unification.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index dfb3861f..a96afa26 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -365,13 +365,15 @@ def _freshen(tp: typing.Any): tp, { fv: typing.TypeVar( - name=f"{fv.__name__}_{random.randint(0, 1 << 32)}", + name=f"{fv.__name__[:100]}_{random.randint(0, 1 << 32)}", bound=fv.__bound__, covariant=fv.__covariant__, contravariant=fv.__contravariant__, ) if isinstance(fv, typing.TypeVar) - else typing.ParamSpec(name=f"{fv.__name__}_{random.randint(0, 1 << 32)}") + else typing.ParamSpec( + name=f"{fv.__name__[:100]}_{random.randint(0, 1 << 32)}" + ) for fv in freetypevars(tp) if isinstance(fv, typing.TypeVar | typing.ParamSpec) }, From 97f486275b61970b9922e5b85eebe6e43ac82120 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 13:04:06 -0400 Subject: [PATCH 62/76] doctest --- effectful/internals/unification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index a96afa26..2aab5473 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -356,9 +356,9 @@ def _freshen(tp: typing.Any): Examples: >>> import typing >>> T = typing.TypeVar('T') - >>> isinstance(freshen(T), typing.TypeVar) + >>> isinstance(_freshen(T), typing.TypeVar) True - >>> freshen(T) == T + >>> _freshen(T) == T False """ return substitute( From 23de369de2284fd9e0f706511b66cdd3b0e5fa7e Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 13:04:53 -0400 Subject: [PATCH 63/76] doctest --- effectful/internals/unification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 2aab5473..cff2f942 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -48,8 +48,8 @@ dict[str, list[int]] >>> # Find all type variables in a type expression - >>> freetypevars(dict[K, list[V]]) - {~K, ~V} + >>> freetypevars(dict[str, list[V]]) + {~V} This module is primarily used internally by effectful for type inference in its effect system, allowing it to track and propagate type information through From f8247da2fbcb09926df236054d3b7f653edbb045 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 14:17:55 -0400 Subject: [PATCH 64/76] add compositional tests --- tests/test_internals_unification.py | 548 ++++++++++++++++++++++++++++ 1 file changed, 548 insertions(+) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index a4673666..d113ec1a 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -639,3 +639,551 @@ def test_nested_type_term_error(): mock_term = Mock(spec=Term) with pytest.raises(TypeError, match="Terms should not appear in nested_type"): nested_type(mock_term) + + +def sequence_getitem(seq: collections.abc.Sequence[T], index: int) -> T: + return seq[index] + + +def mapping_getitem(mapping: collections.abc.Mapping[K, V], key: K) -> V: + return mapping[key] + + +def sequence_mapping_getitem( + seq: collections.abc.Sequence[collections.abc.Mapping[K, V]], index: int, key: K +) -> V: + return mapping_getitem(sequence_getitem(seq, index), key) + + +def mapping_sequence_getitem( + mapping: collections.abc.Mapping[K, collections.abc.Sequence[T]], key: K, index: int +) -> T: + return sequence_getitem(mapping_getitem(mapping, key), index) + + +def sequence_from_pair(a: T, b: T) -> collections.abc.Sequence[T]: + return [a, b] + + +def mapping_from_pair(a: K, b: V) -> collections.abc.Mapping[K, V]: + return {a: b} + + +def sequence_of_mappings( + key1: K, val1: V, key2: K, val2: V +) -> collections.abc.Sequence[collections.abc.Mapping[K, V]]: + """Creates a sequence containing two mappings.""" + return sequence_from_pair( + mapping_from_pair(key1, val1), mapping_from_pair(key2, val2) + ) + + +def mapping_of_sequences( + key1: K, val1: T, val2: T, key2: K, val3: T, val4: T +) -> collections.abc.Mapping[K, collections.abc.Sequence[T]]: + """Creates a mapping where each key maps to a sequence of two values.""" + return mapping_from_pair(key1, sequence_from_pair(val1, val2)) + + +def nested_sequence_mapping( + k1: K, v1: T, v2: T, k2: K, v3: T, v4: T +) -> collections.abc.Sequence[collections.abc.Mapping[K, collections.abc.Sequence[T]]]: + """Creates a sequence of mappings, where each mapping contains sequences.""" + return sequence_from_pair( + mapping_from_pair(k1, sequence_from_pair(v1, v2)), + mapping_from_pair(k2, sequence_from_pair(v3, v4)), + ) + + +def get_from_constructed_sequence(a: T, b: T, index: int) -> T: + """Constructs a sequence from two elements and gets one by index.""" + return sequence_getitem(sequence_from_pair(a, b), index) + + +def get_from_constructed_mapping(key: K, value: V, lookup_key: K) -> V: + """Constructs a mapping from a key-value pair and looks up the value.""" + return mapping_getitem(mapping_from_pair(key, value), lookup_key) + + +def double_nested_get( + k1: K, + v1: T, + v2: T, + k2: K, + v3: T, + v4: T, + outer_index: int, + inner_key: K, + inner_index: int, +) -> T: + """Creates nested structure and retrieves deeply nested value.""" + nested = nested_sequence_mapping(k1, v1, v2, k2, v3, v4) + mapping = sequence_getitem(nested, outer_index) + sequence = mapping_getitem(mapping, inner_key) + return sequence_getitem(sequence, inner_index) + + +def construct_and_extend_sequence( + a: T, b: T, c: T, d: T +) -> collections.abc.Sequence[collections.abc.Sequence[T]]: + """Constructs two sequences and combines them into a sequence of sequences.""" + seq1 = sequence_from_pair(a, b) + seq2 = sequence_from_pair(c, d) + return sequence_from_pair(seq1, seq2) + + +def transform_mapping_values( + key1: K, val1: T, key2: K, val2: T +) -> collections.abc.Mapping[K, collections.abc.Sequence[T]]: + """Creates a mapping where each value is wrapped in a sequence.""" + # Create mappings where each value becomes a single-element sequence + # Note: In a real implementation, we'd need a sequence_from_single function + # For now, using sequence_from_pair with the same value twice as a workaround + return mapping_from_pair(key1, sequence_from_pair(val1, val1)) + + +@pytest.mark.parametrize( + "seq,index,key", + [ + # Original test case: list of dicts with string keys and int values + ([{"a": 1}, {"b": 2}, {"c": 3}], 1, "b"), + # Different value types + ([{"x": "hello"}, {"y": "world"}, {"z": "test"}], 2, "z"), + ([{"name": 3.14}, {"value": 2.71}, {"constant": 1.41}], 0, "name"), + ([{"flag": True}, {"enabled": False}, {"active": True}], 1, "enabled"), + # Mixed value types in same dict (should still work) + ([{"a": [1, 2, 3]}, {"b": [4, 5, 6]}, {"c": [7, 8, 9]}], 0, "a"), + ([{"data": {"nested": "value"}}, {"info": {"deep": "data"}}], 1, "info"), + # Different key types + ([{1: "one"}, {2: "two"}, {3: "three"}], 2, 3), + ([{True: "yes"}, {False: "no"}], 0, True), + # Nested collections as values + ([{"items": [1, 2, 3]}, {"values": [4, 5, 6]}], 0, "items"), + ([{"matrix": [[1, 2], [3, 4]]}, {"grid": [[5, 6], [7, 8]]}], 1, "grid"), + ([{"sets": {1, 2, 3}}, {"groups": {4, 5, 6}}], 0, "sets"), + # Complex nested structures + ( + [ + {"users": [{"id": 1, "name": "Alice"}]}, + {"users": [{"id": 2, "name": "Bob"}]}, + ], + 1, + "users", + ), + ( + [ + {"config": {"db": {"host": "localhost", "port": 5432}}}, + {"config": {"cache": {"ttl": 300}}}, + ], + 0, + "config", + ), + # Edge cases with single element sequences + ([{"only": "one"}], 0, "only"), + # Tuples as values + ([{"point": (1, 2)}, {"coord": (3, 4)}, {"pos": (5, 6)}], 2, "pos"), + ([{"rgb": (255, 0, 0)}, {"hsv": (0, 100, 100)}], 0, "rgb"), + ], +) +def test_infer_composition_1(seq, index, key): + sig1 = inspect.signature(sequence_getitem) + sig2 = inspect.signature(mapping_getitem) + + sig12 = inspect.signature(sequence_mapping_getitem) + + inferred_type1 = substitute( + sig1.return_annotation, + unify(sig1, sig1.bind(nested_type(seq), nested_type(index))), + ) + + inferred_type2 = substitute( + sig2.return_annotation, + unify(sig2, sig2.bind(nested_type(inferred_type1), nested_type(key))), + ) + + inferred_type12 = substitute( + sig12.return_annotation, + unify( + sig12, + sig12.bind(nested_type(seq), nested_type(index), nested_type(key)), + ), + ) + + # check that the composed inference matches the direct inference + assert isinstance(unify(inferred_type2, inferred_type12), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify(nested_type(sequence_mapping_getitem(seq, index, key)), inferred_type12), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "mapping,key,index", + [ + # Dict of lists with string keys + ( + { + "fruits": ["apple", "banana", "cherry"], + "colors": ["red", "green", "blue"], + }, + "fruits", + 1, + ), + ({"numbers": [1, 2, 3, 4, 5], "primes": [2, 3, 5, 7, 11]}, "primes", 3), + # Different value types in sequences + ({"floats": [1.1, 2.2, 3.3], "constants": [3.14, 2.71, 1.41]}, "constants", 0), + ( + {"flags": [True, False, True, False], "states": [False, True, False]}, + "flags", + 2, + ), + # Nested structures + ( + {"matrix": [[1, 2, 3], [4, 5, 6], [7, 8, 9]], "identity": [[1, 0], [0, 1]]}, + "matrix", + 1, + ), + ( + {"teams": [{"name": "A", "score": 10}, {"name": "B", "score": 20}]}, + "teams", + 0, + ), + # Different key types + ( + { + 1: ["one", "uno", "un"], + 2: ["two", "dos", "deux"], + 3: ["three", "tres", "trois"], + }, + 2, + 1, + ), + ({True: ["yes", "true", "1"], False: ["no", "false", "0"]}, False, 2), + # Lists of different collection types + ( + {"data": [{"a": 1}, {"b": 2}, {"c": 3}], "info": [{"x": 10}, {"y": 20}]}, + "data", + 2, + ), + # Edge cases + ({"single": ["only"]}, "single", 0), + ({"empty_key": [], "full": [1, 2, 3]}, "full", 1), + # Complex nested case + ( + { + "users": [ + {"id": 1, "tags": ["admin", "user"]}, + {"id": 2, "tags": ["user", "guest"]}, + {"id": 3, "tags": ["guest"]}, + ] + }, + "users", + 1, + ), + # More diverse cases + ( + {"names": ["Alice", "Bob", "Charlie", "David"], "ages": [25, 30, 35, 40]}, + "names", + 3, + ), + ( + {"options": [[1, 2], [3, 4], [5, 6]], "choices": [[7], [8], [9]]}, + "options", + 2, + ), + # Deeply nested lists + ( + {"deep": [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], "shallow": [[9, 10]]}, + "deep", + 0, + ), + ], +) +def test_infer_composition_2(mapping, key, index): + sig1 = inspect.signature(mapping_getitem) + sig2 = inspect.signature(sequence_getitem) + + sig12 = inspect.signature(mapping_sequence_getitem) + + # First infer type of mapping_getitem(mapping, key) -> should be a sequence + inferred_type1 = substitute( + sig1.return_annotation, + unify(sig1, sig1.bind(nested_type(mapping), nested_type(key))), + ) + + # Then infer type of sequence_getitem(result_from_step1, index) -> should be element type + inferred_type2 = substitute( + sig2.return_annotation, + unify(sig2, sig2.bind(nested_type(inferred_type1), nested_type(index))), + ) + + # Directly infer type of mapping_sequence_getitem(mapping, key, index) + inferred_type12 = substitute( + sig12.return_annotation, + unify( + sig12, + sig12.bind(nested_type(mapping), nested_type(key), nested_type(index)), + ), + ) + + # The composed inference should match the direct inference + assert isinstance(unify(inferred_type2, inferred_type12), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify( + nested_type(mapping_sequence_getitem(mapping, key, index)), inferred_type12 + ), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "a,b,index", + [ + # Basic types + (1, 2, 0), + (1, 2, 1), + ("hello", "world", 0), + (3.14, 2.71, 1), + (True, False, 0), + # Complex types + ([1, 2], [3, 4], 1), + ({"a": 1}, {"b": 2}, 0), + ({1, 2}, {3, 4}, 1), + # Mixed but same types + ([1, 2, 3], [4, 5], 0), + ({"x": "a", "y": "b"}, {"z": "c"}, 1), + ], +) +def test_get_from_constructed_sequence(a, b, index): + """Test type inference through sequence construction and retrieval.""" + sig_construct = inspect.signature(sequence_from_pair) + sig_getitem = inspect.signature(sequence_getitem) + sig_composed = inspect.signature(get_from_constructed_sequence) + + # Infer type of sequence_from_pair(a, b) -> Sequence[T] + construct_subs = unify( + sig_construct, sig_construct.bind(nested_type(a), nested_type(b)) + ) + inferred_sequence_type = substitute(sig_construct.return_annotation, construct_subs) + + # Infer type of sequence_getitem(sequence, index) -> T + getitem_subs = unify( + sig_getitem, sig_getitem.bind(inferred_sequence_type, nested_type(index)) + ) + inferred_element_type = substitute(sig_getitem.return_annotation, getitem_subs) + + # Directly infer type of get_from_constructed_sequence(a, b, index) + direct_subs = unify( + sig_composed, + sig_composed.bind(nested_type(a), nested_type(b), nested_type(index)), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance( + unify(inferred_element_type, direct_type), collections.abc.Mapping + ) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify(nested_type(get_from_constructed_sequence(a, b, index)), direct_type), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "key,value,lookup_key", + [ + # Basic types + ("name", "Alice", "name"), + (1, "one", 1), + (True, "yes", True), + (3.14, "pi", 3.14), + # Complex value types + ("data", [1, 2, 3], "data"), + ("config", {"host": "localhost", "port": 8080}, "config"), + ("items", {1, 2, 3}, "items"), + # Different key types + (42, {"value": "answer"}, 42), + ("key", (1, 2, 3), "key"), + ], +) +def test_get_from_constructed_mapping(key, value, lookup_key): + """Test type inference through mapping construction and retrieval.""" + sig_construct = inspect.signature(mapping_from_pair) + sig_getitem = inspect.signature(mapping_getitem) + sig_composed = inspect.signature(get_from_constructed_mapping) + + # Infer type of mapping_from_pair(key, value) -> Mapping[K, V] + construct_subs = unify( + sig_construct, sig_construct.bind(nested_type(key), nested_type(value)) + ) + inferred_mapping_type = substitute(sig_construct.return_annotation, construct_subs) + + # Infer type of mapping_getitem(mapping, lookup_key) -> V + getitem_subs = unify( + sig_getitem, sig_getitem.bind(inferred_mapping_type, nested_type(lookup_key)) + ) + inferred_value_type = substitute(sig_getitem.return_annotation, getitem_subs) + + # Directly infer type of get_from_constructed_mapping(key, value, lookup_key) + direct_subs = unify( + sig_composed, + sig_composed.bind( + nested_type(key), nested_type(value), nested_type(lookup_key) + ), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(inferred_value_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify( + nested_type(get_from_constructed_mapping(key, value, lookup_key)), + direct_type, + ), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "key1,val1,key2,val2,index", + [ + # Basic case + ("a", 1, "b", 2, 0), + ("x", "hello", "y", "world", 1), + # Different types + (1, "one", 2, "two", 0), + (True, 1.0, False, 0.0, 1), + # Complex values + ("list1", [1, 2], "list2", [3, 4], 0), + ("dict1", {"a": 1}, "dict2", {"b": 2}, 1), + ], +) +def test_sequence_of_mappings(key1, val1, key2, val2, index): + """Test type inference for creating a sequence of mappings.""" + sig_map = inspect.signature(mapping_from_pair) + sig_seq = inspect.signature(sequence_from_pair) + sig_composed = inspect.signature(sequence_of_mappings) + + # Step 1: Infer types of the two mappings + map1_subs = unify(sig_map, sig_map.bind(nested_type(key1), nested_type(val1))) + map1_type = substitute(sig_map.return_annotation, map1_subs) + + # Step 2: Infer type of sequence containing these mappings + # We need to unify the two mapping types first + unified_map_type = map1_type # Assuming they're compatible + + seq_subs = unify(sig_seq, sig_seq.bind(unified_map_type, unified_map_type)) + seq_type = substitute(sig_seq.return_annotation, seq_subs) + + # Direct inference + direct_subs = unify( + sig_composed, + sig_composed.bind( + nested_type(key1), nested_type(val1), nested_type(key2), nested_type(val2) + ), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The types should match + assert isinstance(unify(seq_type, direct_type), collections.abc.Mapping) + + # Note: nested_type(sequence_of_mappings(...)) returns concrete types (list[dict[K,V]]) + # while our function signature uses abstract types (Sequence[Mapping[K,V]]) + # This is expected behavior - concrete implementations vs abstract interfaces + + +@pytest.mark.parametrize( + "k1,v1,v2,k2,v3,v4,outer_idx,inner_key,inner_idx", + [ + # Basic test case + ("first", 1, 2, "second", 3, 4, 0, "first", 1), + ("a", "x", "y", "b", "z", "w", 1, "b", 0), + # Different types + (1, 10.0, 20.0, 2, 30.0, 40.0, 0, 1, 1), + ("data", [1], [2], "info", [3], [4], 1, "info", 0), + ], +) +def test_double_nested_get(k1, v1, v2, k2, v3, v4, outer_idx, inner_key, inner_idx): + """Test type inference through deeply nested structure construction and retrieval.""" + # Get signatures for all functions involved + sig_nested = inspect.signature(nested_sequence_mapping) + sig_seq_get = inspect.signature(sequence_getitem) + sig_map_get = inspect.signature(mapping_getitem) + sig_composed = inspect.signature(double_nested_get) + + # Step 1: Infer type of nested_sequence_mapping construction + nested_subs = unify( + sig_nested, + sig_nested.bind( + nested_type(k1), + nested_type(v1), + nested_type(v2), + nested_type(k2), + nested_type(v3), + nested_type(v4), + ), + ) + nested_seq_type = substitute(sig_nested.return_annotation, nested_subs) + # This should be Sequence[Mapping[K, Sequence[T]]] + + # Step 2: Get element from outer sequence + outer_get_subs = unify( + sig_seq_get, sig_seq_get.bind(nested_seq_type, nested_type(outer_idx)) + ) + mapping_type = substitute(sig_seq_get.return_annotation, outer_get_subs) + # This should be Mapping[K, Sequence[T]] + + # Step 3: Get sequence from mapping + inner_map_subs = unify( + sig_map_get, sig_map_get.bind(mapping_type, nested_type(inner_key)) + ) + sequence_type = substitute(sig_map_get.return_annotation, inner_map_subs) + # This should be Sequence[T] + + # Step 4: Get element from inner sequence + final_get_subs = unify( + sig_seq_get, sig_seq_get.bind(sequence_type, nested_type(inner_idx)) + ) + composed_type = substitute(sig_seq_get.return_annotation, final_get_subs) + # This should be T + + # Direct inference on the composed function + direct_subs = unify( + sig_composed, + sig_composed.bind( + nested_type(k1), + nested_type(v1), + nested_type(v2), + nested_type(k2), + nested_type(v3), + nested_type(v4), + nested_type(outer_idx), + nested_type(inner_key), + nested_type(inner_idx), + ), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify( + nested_type( + double_nested_get( + k1, v1, v2, k2, v3, v4, outer_idx, inner_key, inner_idx + ) + ), + direct_type, + ), + collections.abc.Mapping, + ) From e173387990ebf18664b8fcb211228a76b66f902a Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 14:42:37 -0400 Subject: [PATCH 65/76] add tests for function types --- tests/test_internals_unification.py | 402 ++++++++++++++++++++++++++++ 1 file changed, 402 insertions(+) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index d113ec1a..95b00e4f 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -16,11 +16,13 @@ K = typing.Any V = typing.Any U = typing.Any + W = typing.Any else: T = typing.TypeVar("T") K = typing.TypeVar("K") V = typing.TypeVar("V") U = typing.TypeVar("U") + W = typing.TypeVar("W") @pytest.mark.parametrize( @@ -742,6 +744,178 @@ def transform_mapping_values( return mapping_from_pair(key1, sequence_from_pair(val1, val1)) +def call_func( + func: collections.abc.Callable[[T], V], + arg: T, +) -> V: + """Calls a function with a single argument.""" + return func(arg) + + +def call_binary_func( + func: collections.abc.Callable[[T, U], V], + arg1: T, + arg2: U, +) -> V: + """Calls a binary function with two arguments.""" + return func(arg1, arg2) + + +def map_sequence( + f: collections.abc.Callable[[T], U], + seq: collections.abc.Sequence[T], +) -> collections.abc.Sequence[U]: + """Applies a function to each element in a sequence.""" + return [call_func(f, x) for x in seq] + + +def compose_mappings( + f: collections.abc.Callable[[T], U], + g: collections.abc.Callable[[U], V], +) -> collections.abc.Callable[[T], V]: + """Composes two functions that operate on mappings.""" + + def composed(x: T) -> V: + return call_func(g, call_func(f, x)) + + return composed + + +def compose_binary( + f: collections.abc.Callable[[T], U], + g: collections.abc.Callable[[U, U], V], +) -> collections.abc.Callable[[T], V]: + """Composes a unary function with a binary function.""" + + def composed(x: T) -> V: + return call_binary_func(g, call_func(f, x), call_func(f, x)) + + return composed + + +def apply_to_sequence_element( + f: collections.abc.Callable[[T], U], + seq: collections.abc.Sequence[T], + index: int, +) -> U: + """Gets an element from a sequence and applies a function to it.""" + element = sequence_getitem(seq, index) + return call_func(f, element) + + +def map_and_get( + f: collections.abc.Callable[[T], U], + seq: collections.abc.Sequence[T], + index: int, +) -> U: + """Maps a function over a sequence and gets element at index.""" + mapped_seq = map_sequence(f, seq) + return sequence_getitem(mapped_seq, index) + + +def compose_and_apply( + f: collections.abc.Callable[[T], U], + g: collections.abc.Callable[[U], V], + value: T, +) -> V: + """Composes two functions and applies the result to a value.""" + composed = compose_mappings(f, g) + return call_func(composed, value) + + +def double_compose_apply( + f: collections.abc.Callable[[T], U], + g: collections.abc.Callable[[U], V], + h: collections.abc.Callable[[V], W], + value: T, +) -> W: + """Composes three functions and applies to a value.""" + fg = compose_mappings(f, g) + fgh = compose_mappings(fg, h) + return call_func(fgh, value) + + +def binary_on_sequence_elements( + f: collections.abc.Callable[[T, T], U], + seq: collections.abc.Sequence[T], + index1: int, + index2: int, +) -> U: + """Gets two elements from a sequence and applies a binary function.""" + elem1 = sequence_getitem(seq, index1) + elem2 = sequence_getitem(seq, index2) + return call_binary_func(f, elem1, elem2) + + +def map_sequence_and_apply_binary( + f: collections.abc.Callable[[T], U], + g: collections.abc.Callable[[U, U], V], + seq: collections.abc.Sequence[T], + index1: int, + index2: int, +) -> V: + """Maps a function over sequence, then applies binary function to two elements.""" + mapped = map_sequence(f, seq) + elem1 = sequence_getitem(mapped, index1) + elem2 = sequence_getitem(mapped, index2) + return call_binary_func(g, elem1, elem2) + + +def construct_apply_and_get( + f: collections.abc.Callable[[T], U], + a: T, + b: T, + index: int, +) -> U: + """Constructs a sequence, applies function to elements, and gets one.""" + seq = sequence_from_pair(a, b) + return apply_to_sequence_element(f, seq, index) + + +def sequence_function_composition( + funcs: collections.abc.Sequence[collections.abc.Callable[[T], T]], + value: T, +) -> T: + """Applies a sequence of functions in order to a value.""" + result = value + for func in funcs: + result = call_func(func, result) + return result + + +def map_with_constructed_function( + f: collections.abc.Callable[[T], U], + g: collections.abc.Callable[[U], V], + seq: collections.abc.Sequence[T], +) -> collections.abc.Sequence[V]: + """Composes two functions and maps the result over a sequence.""" + composed = compose_mappings(f, g) + return map_sequence(composed, seq) + + +def cross_apply_binary( + f: collections.abc.Callable[[T, U], V], + seq1: collections.abc.Sequence[T], + seq2: collections.abc.Sequence[U], + index1: int, + index2: int, +) -> V: + """Gets elements from two sequences and applies a binary function.""" + elem1 = sequence_getitem(seq1, index1) + elem2 = sequence_getitem(seq2, index2) + return call_binary_func(f, elem1, elem2) + + +def nested_function_application( + outer_f: collections.abc.Callable[[T], collections.abc.Callable[[U], V]], + inner_arg: U, + outer_arg: T, +) -> V: + """Applies a function that returns a function, then applies the result.""" + inner_f = call_func(outer_f, outer_arg) + return call_func(inner_f, inner_arg) + + @pytest.mark.parametrize( "seq,index,key", [ @@ -1187,3 +1361,231 @@ def test_double_nested_get(k1, v1, v2, k2, v3, v4, outer_idx, inner_key, inner_i ), collections.abc.Mapping, ) + + +@pytest.mark.parametrize( + "f,seq,index", + [ + # Basic function applications + (lambda x: x * 2, [1, 2, 3], 0), + (lambda x: x * 2, [1, 2, 3], 2), + (lambda x: x.upper(), ["hello", "world"], 1), + (lambda x: len(x), ["a", "bb", "ccc"], 2), + (lambda x: x + 1.0, [1.0, 2.0, 3.0], 1), + ], +) +def test_apply_to_sequence_element(f, seq, index): + """Test type inference through sequence access and function application.""" + sig_getitem = inspect.signature(sequence_getitem) + sig_call = inspect.signature(call_func) + sig_composed = inspect.signature(apply_to_sequence_element) + + # Step 1: Infer type of sequence_getitem(seq, index) -> T + getitem_subs = unify( + sig_getitem, sig_getitem.bind(nested_type(seq), nested_type(index)) + ) + element_type = substitute(sig_getitem.return_annotation, getitem_subs) + + # Step 2: Infer type of call_func(f, element) -> U + call_subs = unify(sig_call, sig_call.bind(nested_type(f), element_type)) + composed_type = substitute(sig_call.return_annotation, call_subs) + + # Direct inference + direct_subs = unify( + sig_composed, + sig_composed.bind(nested_type(f), nested_type(seq), nested_type(index)), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify(nested_type(apply_to_sequence_element(f, seq, index)), direct_type), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "f,seq,index", + [ + # Basic transformations + (lambda x: x * 2, [1, 2, 3], 1), + (lambda x: x.upper(), ["hello", "world"], 0), + (lambda x: len(x), ["a", "bb", "ccc"], 2), + (lambda x: x + 1, [10, 20, 30], 0), + ], +) +def test_map_and_get(f, seq, index): + """Test type inference through mapping and element retrieval.""" + sig_map = inspect.signature(map_sequence) + sig_getitem = inspect.signature(sequence_getitem) + sig_composed = inspect.signature(map_and_get) + + # Step 1: Infer type of map_sequence(f, seq) -> Sequence[U] + map_subs = unify(sig_map, sig_map.bind(nested_type(f), nested_type(seq))) + mapped_type = substitute(sig_map.return_annotation, map_subs) + + # Step 2: Infer type of sequence_getitem(mapped_seq, index) -> U + getitem_subs = unify(sig_getitem, sig_getitem.bind(mapped_type, nested_type(index))) + composed_type = substitute(sig_getitem.return_annotation, getitem_subs) + + # Direct inference + direct_subs = unify( + sig_composed, + sig_composed.bind(nested_type(f), nested_type(seq), nested_type(index)), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify(nested_type(map_and_get(f, seq, index)), direct_type), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "f,g,value", + [ + # Basic function compositions + (lambda x: x * 2, lambda x: x + 1, 5), + (lambda x: str(x), lambda x: x.upper(), 42), + (lambda x: len(x), lambda x: x * 2, "hello"), + (lambda x: [x], lambda x: x[0], 1), + ], +) +def test_compose_and_apply(f, g, value): + """Test type inference through function composition and application.""" + sig_compose = inspect.signature(compose_mappings) + sig_call = inspect.signature(call_func) + sig_composed = inspect.signature(compose_and_apply) + + # Step 1: Infer type of compose_mappings(f, g) -> Callable[[T], V] + compose_subs = unify(sig_compose, sig_compose.bind(nested_type(f), nested_type(g))) + composed_func_type = substitute(sig_compose.return_annotation, compose_subs) + + # Step 2: Infer type of call_func(composed, value) -> V + call_subs = unify(sig_call, sig_call.bind(composed_func_type, nested_type(value))) + result_type = substitute(sig_call.return_annotation, call_subs) + + # Direct inference + direct_subs = unify( + sig_composed, + sig_composed.bind(nested_type(f), nested_type(g), nested_type(value)), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(result_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify(nested_type(compose_and_apply(f, g, value)), direct_type), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "f,a,b,index", + [ + # Basic constructions and applications + (lambda x: x * 2, 1, 2, 0), + (lambda x: x * 2, 1, 2, 1), + (lambda x: x.upper(), "hello", "world", 0), + (lambda x: len(x), "a", "bb", 1), + ], +) +def test_construct_apply_and_get(f, a, b, index): + """Test type inference through construction, application, and retrieval.""" + sig_construct = inspect.signature(sequence_from_pair) + sig_apply = inspect.signature(apply_to_sequence_element) + sig_composed = inspect.signature(construct_apply_and_get) + + # Step 1: Infer type of sequence_from_pair(a, b) -> Sequence[T] + construct_subs = unify( + sig_construct, sig_construct.bind(nested_type(a), nested_type(b)) + ) + seq_type = substitute(sig_construct.return_annotation, construct_subs) + + # Step 2: Infer type of apply_to_sequence_element(f, seq, index) -> U + apply_subs = unify( + sig_apply, sig_apply.bind(nested_type(f), seq_type, nested_type(index)) + ) + composed_type = substitute(sig_apply.return_annotation, apply_subs) + + # Direct inference + direct_subs = unify( + sig_composed, + sig_composed.bind( + nested_type(f), nested_type(a), nested_type(b), nested_type(index) + ), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify(nested_type(construct_apply_and_get(f, a, b, index)), direct_type), + collections.abc.Mapping, + ) + + +@pytest.mark.parametrize( + "f,seq,index1,index2", + [ + # Basic binary operations + (lambda x, y: x + y, [1, 2, 3], 0, 1), + (lambda x, y: x + y, [1, 2, 3], 1, 2), + (lambda x, y: x + y, ["hello", "world", "test"], 0, 2), + (lambda x, y: x * y, [2, 3, 4], 0, 2), + ], +) +def test_binary_on_sequence_elements(f, seq, index1, index2): + """Test type inference through sequence access and binary function application.""" + sig_getitem = inspect.signature(sequence_getitem) + sig_call_binary = inspect.signature(call_binary_func) + sig_composed = inspect.signature(binary_on_sequence_elements) + + # Step 1: Infer types of sequence_getitem calls + getitem1_subs = unify( + sig_getitem, sig_getitem.bind(nested_type(seq), nested_type(index1)) + ) + elem1_type = substitute(sig_getitem.return_annotation, getitem1_subs) + + getitem2_subs = unify( + sig_getitem, sig_getitem.bind(nested_type(seq), nested_type(index2)) + ) + elem2_type = substitute(sig_getitem.return_annotation, getitem2_subs) + + # Step 2: Infer type of call_binary_func(f, elem1, elem2) -> V + call_subs = unify( + sig_call_binary, sig_call_binary.bind(nested_type(f), elem1_type, elem2_type) + ) + composed_type = substitute(sig_call_binary.return_annotation, call_subs) + + # Direct inference + direct_subs = unify( + sig_composed, + sig_composed.bind( + nested_type(f), nested_type(seq), nested_type(index1), nested_type(index2) + ), + ) + direct_type = substitute(sig_composed.return_annotation, direct_subs) + + # The composed inference should match the direct inference + assert isinstance(unify(composed_type, direct_type), collections.abc.Mapping) + + # check that the result of nested_type on the value of the composition unifies with the inferred type + assert isinstance( + unify( + nested_type(binary_on_sequence_elements(f, seq, index1, index2)), + direct_type, + ), + collections.abc.Mapping, + ) From 6b532af57accbf0fc9dad9b43bdef051cb5d8430 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 16:26:31 -0400 Subject: [PATCH 66/76] ellipsis and paramspec --- effectful/internals/unification.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index cff2f942..8c27acc4 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -323,10 +323,14 @@ def _( @unify.register def _( typ: typing.ParamSpec, - subtyp: typing.ParamSpec | collections.abc.Sequence, + subtyp: typing.ParamSpec | collections.abc.Sequence | types.EllipsisType, subs: Substitutions = {}, ) -> Substitutions: - return subs if typ is subtyp else {typ: subtyp, **subs} + return ( + subs + if typ is subtyp or isinstance(subtyp, types.EllipsisType) + else {typ: subtyp, **subs} + ) @unify.register @@ -335,7 +339,10 @@ def _( subtyp: types.EllipsisType | collections.abc.Sequence, subs: Substitutions = {}, ) -> Substitutions: - return subs + if isinstance(subtyp, types.EllipsisType | collections.abc.Sequence): + return subs + else: + raise TypeError(f"Cannot unify type {typ} with {subtyp} given {subs}. ") def _freshen(tp: typing.Any): From d9578bbecaa73b7982bf218735baee9c6931c3e1 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 17:15:08 -0400 Subject: [PATCH 67/76] variadic tuple logic and union tests --- effectful/internals/unification.py | 29 ++++++++++++++++++++++++++++- tests/test_internals_unification.py | 26 ++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 8c27acc4..7084c7a0 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -258,7 +258,34 @@ def _( subtyp: type | types.GenericAlias | typing.TypeVar | types.UnionType, subs: Substitutions = {}, ) -> Substitutions: - if isinstance(subtyp, GenericAlias): + if ( + issubclass(typing.get_origin(typ), tuple) + and len(typing.get_args(typ)) == 2 + and typing.get_args(typ)[-1] is Ellipsis + ): + # Logic for normalizing variadic tuples in typ + subs = unify(typing.get_origin(typ), typing.get_origin(subtyp), subs) + return unify(collections.abc.Sequence[typing.get_args(typ)[0]], subtyp, subs) # type: ignore + elif ( + isinstance(subtyp, GenericAlias) + and issubclass(typing.get_origin(subtyp), tuple) + and typing.get_args(subtyp)[-1] is Ellipsis + ): + # Logic for normalizing variadic tuples in subtyp + subs = unify(typing.get_origin(typ), typing.get_origin(subtyp), subs) + return unify(typ, collections.abc.Sequence[typing.get_args(subtyp)[0]], subs) # type: ignore + elif ( + isinstance(subtyp, GenericAlias) + and issubclass(typing.get_origin(subtyp), tuple) + and not issubclass(typing.get_origin(typ), tuple) + ): + # non-tuple, non-variadic tuple + subs = unify(typing.get_origin(typ), typing.get_origin(subtyp), subs) + subtyp_ = typing.get_args(subtyp)[0] + for st in typing.get_args(subtyp)[1:]: + subtyp_ = subtyp_ | st + return unify(typing.get_args(typ)[0], subtyp_, subs) + elif isinstance(subtyp, GenericAlias): subs = unify(typing.get_origin(typ), typing.get_origin(subtyp), subs) return unify(typing.get_args(typ), typing.get_args(subtyp), subs) else: diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 95b00e4f..c2f16156 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -333,6 +333,32 @@ def test_unify_failure( unify(typ, subtyp, {}) +def test_unify_union_1(): + assert unify(int | str, int | str) == {} + assert unify(int | str, str) == {} + assert unify(int | str, int) == {} + + assert unify(T, int | str) == {T: int | str} + assert unify(T | V, int | str) == {T: int | str, V: int | str} + + +def test_unify_tuple_variadic(): + assert unify(tuple[T, ...], tuple[int, ...]) == {T: int} + assert unify(tuple[T, ...], tuple[int]) == {T: int} + assert unify(tuple[T, ...], tuple[int, int]) == {T: int} + assert unify(tuple[T, ...], tuple[int, str]) == {T: int | str} + assert unify(collections.abc.Sequence[T], tuple[int, ...]) == {T: int} + + +def test_unify_tuple_non_variadic(): + assert unify(tuple[T], tuple[int | str]) == {T: int | str} + assert unify(tuple[T, V], tuple[int, str]) == {T: int, V: str} + assert unify(tuple[T, T], tuple[int, int]) == {T: int} + assert unify(tuple[T, T, T], tuple[str, str, str]) == {T: str} + assert unify(collections.abc.Sequence[T], tuple[int, str]) == {T: int | str} + assert unify(collections.abc.Sequence[T], tuple[int, int]) == {T: int} + + # Test functions with various type patterns def identity(x: T) -> T: return x From 05a78490034a862e224f739fae71a3eaba2d16be Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 10 Jul 2025 17:31:46 -0400 Subject: [PATCH 68/76] fix union unify pattern --- effectful/internals/unification.py | 26 ++++++++++++++++---------- tests/test_internals_unification.py | 1 - 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 7084c7a0..eb674c61 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -236,20 +236,26 @@ def _( @unify.register def _( typ: types.UnionType, - subtyp: type, + subtyp: type | types.UnionType, subs: Substitutions = {}, ) -> Substitutions: - any_succeeded = False - for arg in typing.get_args(typ): - try: - subs = unify(arg, subtyp, subs) - any_succeeded = True - except TypeError: # noqa - continue - if any_succeeded: + if isinstance(subtyp, types.UnionType): + # If subtyp is a union, try to unify with each argument + for arg in typing.get_args(subtyp): + subs = unify(typ, arg, subs) return subs else: - raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}") + any_succeeded = False + for arg in typing.get_args(typ): + try: + subs = unify(arg, subtyp, subs) + any_succeeded = True + except TypeError: # noqa + continue + if any_succeeded: + return subs + else: + raise TypeError(f"Cannot unify {typ} with {subtyp} given {subs}") @unify.register diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index c2f16156..538625ce 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -339,7 +339,6 @@ def test_unify_union_1(): assert unify(int | str, int) == {} assert unify(T, int | str) == {T: int | str} - assert unify(T | V, int | str) == {T: int | str, V: int | str} def test_unify_tuple_variadic(): From e0cf9b319e0dcc318c2ad94ec2244146b79c3e07 Mon Sep 17 00:00:00 2001 From: eb8680 Date: Fri, 11 Jul 2025 12:09:39 -0400 Subject: [PATCH 69/76] Use unification to implement `Operation.__type_rule__` (#300) * Use infer_return_type to implement Operation.__type_rule__ * move some logic out of infer_return_type * dont add defaults * update * remove duplicate default param logic * fix and format * lint * fix semiring test * address comment --- docs/source/semi_ring.py | 24 +++++-------- effectful/ops/semantics.py | 21 +++++++++++- effectful/ops/syntax.py | 62 ++++++++++++---------------------- tests/test_handlers_numbers.py | 17 +++++++--- tests/test_semi_ring.py | 42 +++++++++++------------ 5 files changed, 84 insertions(+), 82 deletions(-) diff --git a/docs/source/semi_ring.py b/docs/source/semi_ring.py index d8a8bcad..7b8899a2 100644 --- a/docs/source/semi_ring.py +++ b/docs/source/semi_ring.py @@ -1,7 +1,7 @@ import collections.abc import operator import types -from typing import Annotated, ParamSpec, Tuple, TypeVar, Union, cast, overload +from typing import Annotated, ParamSpec, TypeVar, cast, overload import effectful.handlers.numbers # noqa: F401 from effectful.ops.semantics import coproduct, evaluate, fwd, handler @@ -70,17 +70,17 @@ def Let( @defop -def Record(**kwargs: T) -> dict[str, T]: +def Record(**kwargs: T) -> collections.abc.Mapping[str, T]: raise NotImplementedError @defop -def Field(record: dict[str, T], key: str) -> T: +def Field(record: collections.abc.Mapping[str, T], key: str) -> T: raise NotImplementedError @defop -def Dict(*contents: Union[K, V]) -> SemiRingDict[K, V]: +def Dict(*contents: tuple[K, V]) -> SemiRingDict[K, V]: raise NotImplementedError @@ -100,20 +100,14 @@ def add(x: T, y: T) -> T: ops.Field = Field -def eager_dict(*contents: Tuple[K, V]) -> SemiRingDict[K, V]: - if not any(isinstance(v, Term) for v in contents): - if len(contents) % 2 != 0: - raise ValueError("Dict requires an even number of arguments") - - kv = [] - for i in range(0, len(contents), 2): - kv.append((contents[i], contents[i + 1])) - return SemiRingDict(kv) +def eager_dict(*contents: tuple[K, V]) -> SemiRingDict[K, V]: + if not any(isinstance(v, Term) for kv in contents for v in kv): + return SemiRingDict(list(contents)) else: return fwd() -def eager_record(**kwargs: T) -> dict[str, T]: +def eager_record(**kwargs: T) -> collections.abc.Mapping[str, T]: if not any(isinstance(v, Term) for v in kwargs.values()): return dict(**kwargs) else: @@ -223,7 +217,7 @@ def vertical_fusion(e1: T, x: Operation[[], T], e2: S) -> S: ) term: SemiRingDict[int, int] = Let( - Sum(x(), k, v, Dict(k(), v() + 1)), y, Sum(y(), k, v, Dict(k(), v() + 1)) + Sum(x(), k, v, Dict((k(), v() + 1))), y, Sum(y(), k, v, Dict((k(), v() + 1))) ) print("Without optimization:", term) diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 38eb0c86..210ffaa4 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -1,5 +1,7 @@ import contextlib import functools +import types +import typing from collections.abc import Callable from typing import Any, TypeVar @@ -294,7 +296,24 @@ def typeof(term: Expr[T]) -> type[T]: from effectful.internals.runtime import interpreter with interpreter({apply: lambda _, op, *a, **k: op.__type_rule__(*a, **k)}): - return evaluate(term) if isinstance(term, Term) else type(term) # type: ignore + if isinstance(term, Term): + # If term is a Term, we evaluate it to get its type + tp = evaluate(term) + if isinstance(tp, typing.TypeVar): + tp = ( + tp.__bound__ + if tp.__bound__ + else tp.__constraints__[0] + if tp.__constraints__ + else object + ) + if isinstance(tp, types.UnionType): + raise TypeError( + f"Cannot determine type of {term} because it is a union type: {tp}" + ) + return typing.get_origin(tp) or tp # type: ignore + else: + return type(term) def fvsof(term: Expr[S]) -> set[Operation]: diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 36f36812..c9c07a0c 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -598,48 +598,28 @@ def __fvs_rule__( return tuple(result_sig.args), dict(result_sig.kwargs) def __type_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> type[V]: - def unwrap_annotation(typ): - """Unwrap Annotated types.""" - return ( - typing.get_args(typ)[0] if typing.get_origin(typ) is Annotated else typ - ) - - def drop_params(typ): - """Strip parameters from polymorphic types.""" - origin = typing.get_origin(typ) - return typ if origin is None else origin - - sig = self.__signature__ - bound_sig = sig.bind(*args, **kwargs) - bound_sig.apply_defaults() - - anno = sig.return_annotation - anno = unwrap_annotation(anno) - - if anno is None: - return typing.cast(type[V], type(None)) - - if anno is inspect.Signature.empty: - return typing.cast(type[V], object) + from effectful.internals.unification import ( + freetypevars, + nested_type, + substitute, + unify, + ) - if isinstance(anno, typing.TypeVar): - # rudimentary but sound special-case type inference sufficient for syntax ops: - # if the return type annotation is a TypeVar, - # look for a parameter with the same annotation and return its type, - # otherwise give up and return Any/object - for name, param in bound_sig.signature.parameters.items(): - param_typ = unwrap_annotation(param.annotation) - if param_typ is anno and param.kind not in ( - inspect.Parameter.VAR_POSITIONAL, - inspect.Parameter.VAR_KEYWORD, - ): - arg = bound_sig.arguments[name] - tp: type[V] = type(arg) if not isinstance(arg, type) else arg - return drop_params(tp) + return_anno = self.__signature__.return_annotation + if typing.get_origin(return_anno) is typing.Annotated: + return_anno = typing.get_args(return_anno)[0] + if return_anno is inspect.Parameter.empty: return typing.cast(type[V], object) + elif return_anno is None: + return type(None) # type: ignore + elif not freetypevars(return_anno): + return return_anno - return drop_params(anno) + type_args = tuple(nested_type(a) for a in args) + type_kwargs = {k: nested_type(v) for k, v in kwargs.items()} + bound_sig = self.__signature__.bind(*type_args, **type_kwargs) + return substitute(return_anno, unify(self.__signature__, bound_sig)) # type: ignore def __repr__(self): return f"_BaseOperation({self._default}, name={self.__name__}, freshening={self._freshening})" @@ -670,6 +650,9 @@ def func(*args, **kwargs): @defop.register(type) +@defop.register(typing.cast(type, types.GenericAlias)) +@defop.register(typing.cast(type, typing._GenericAlias)) # type: ignore +@defop.register(typing.cast(type, types.UnionType)) def _(t: type[T], *, name: str | None = None) -> Operation[[], T]: def func() -> t: # type: ignore raise NotImplementedError @@ -1006,9 +989,6 @@ def _(op, *args, **kwargs): base_term = __dispatch(typing.cast(type[T], object))(op, *args_, **kwargs_) tp = typeof(base_term) - if tp is typing.Union: - raise ValueError("Terms that return Union types are not supported.") - assert isinstance(tp, type) typed_term = __dispatch(tp)(op, *args_, **kwargs_) return typed_term diff --git a/tests/test_handlers_numbers.py b/tests/test_handlers_numbers.py index a995ff8c..7635d548 100644 --- a/tests/test_handlers_numbers.py +++ b/tests/test_handlers_numbers.py @@ -1,6 +1,8 @@ import collections +import collections.abc import logging import os +import typing import pytest @@ -11,6 +13,9 @@ logger = logging.getLogger(__name__) +T = typing.TypeVar("T") +S = typing.TypeVar("S") + def test_lambda_calculus_1(): x, y = defop(int), defop(int) @@ -40,7 +45,11 @@ def test_lambda_calculus_2(): def test_lambda_calculus_3(): - x, y, f = defop(int), defop(int), defop(collections.abc.Callable) + x, y, f = ( + defop(int), + defop(int), + defop(collections.abc.Callable[[int], collections.abc.Callable[[int], int]]), + ) with handler(eager_mixed): f2 = Lam(x, Lam(y, (x() + y()))) @@ -51,8 +60,8 @@ def test_lambda_calculus_3(): def test_lambda_calculus_4(): x, f, g = ( defop(int), - defop(collections.abc.Callable), - defop(collections.abc.Callable), + defop(collections.abc.Callable[[T], T]), + defop(collections.abc.Callable[[T], T]), ) with handler(eager_mixed): @@ -177,7 +186,7 @@ def f2(x: int, y: int) -> int: return x + y @trace - def app2(f: collections.abc.Callable, x: int, y: int) -> int: + def app2(f: collections.abc.Callable[[int, int], int], x: int, y: int) -> int: return f(x, y) assert app2(f2, 1, 2) == 3 diff --git a/tests/test_semi_ring.py b/tests/test_semi_ring.py index 4f85b92a..0418b8c8 100644 --- a/tests/test_semi_ring.py +++ b/tests/test_semi_ring.py @@ -1,10 +1,14 @@ import random +import typing -from docs.source.semi_ring import Dict, Field, Let, Sum, eager, ops, opt +from docs.source.semi_ring import Dict, Field, Let, SemiRingDict, Sum, eager, ops, opt from effectful.ops.semantics import handler from effectful.ops.syntax import defop, trace from effectful.ops.types import Term +S = typing.TypeVar("S") +T = typing.TypeVar("T") + @trace def add1(v: int) -> int: @@ -12,32 +16,32 @@ def add1(v: int) -> int: def test_simple_sum(): - x = defop(str, name="x") - y = defop(object, name="y") + x = defop(SemiRingDict[str, T], name="x") + y = defop(SemiRingDict[str, T], name="y") k = defop(str, name="k") v = defop(int, name="v") with handler(eager): - e = Sum(Dict("a", 1, "b", 2), k, v, Dict("v", v())) + e = Sum(Dict(("a", 1), ("b", 2)), k, v, Dict(("v", v()))) assert e["v"] == 3 with handler(eager): - e = Let(Dict("a", 1, "b", 2), x, Field(x(), "b")) + e = Let(Dict(("a", 1), ("b", 2)), x, Field(x(), "b")) assert e == 2 with handler(eager): - e = Sum(Dict("a", 1, "b", 2), k, v, Dict(k(), add1(add1(v())))) + e = Sum(Dict(("a", 1), ("b", 2)), k, v, Dict((k(), add1(add1(v()))))) assert e["a"] == 3 assert e["b"] == 4 with handler(eager), handler(opt): e = Let( - Dict("a", 1, "b", 2), + Dict(("a", 1), ("b", 2)), x, Let( - Sum(x(), k, v, Dict(k(), add1(v()))), + Sum(x(), k, v, Dict((k(), add1(v())))), y, - Sum(y(), k, v, Dict(k(), add1(v()))), + Sum(y(), k, v, Dict((k(), add1(v())))), ), ) assert e["a"] == 3 @@ -45,19 +49,19 @@ def test_simple_sum(): def fusion_test(d): - x = defop(object, name="x") - y = defop(object, name="y") - k = defop(object, name="k") - v = defop(object, name="v") + x = defop(SemiRingDict[S, T], name="x") + y = defop(SemiRingDict[S, T], name="y") + k = defop(str, name="k") + v = defop(int, name="v") return ( Let( d, x, Let( - Sum(x(), k, v, Dict(k(), add1(v()))), + Sum(x(), k, v, Dict((k(), add1(v())))), y, - Sum(y(), k, v, Dict(k(), add1(v()))), + Sum(y(), k, v, Dict((k(), add1(v())))), ), ), (x, y, k, v), @@ -65,15 +69,11 @@ def fusion_test(d): def make_dict(n): - kv = [] - for i in range(n): - kv.append(i) - kv.append(random.randint(1, 10)) - return Dict(*kv) + return Dict(*[(i, random.randint(1, 10)) for i in range(n)]) def test_fusion_term(): - dvar = defop(object, name="dvar") + dvar = defop(SemiRingDict[str, T], name="dvar") with handler(eager), handler(opt): result, (x, _, k, v) = fusion_test(dvar()) From 2345f12935920ac587579b22924b648926e90778 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 11 Jul 2025 14:02:20 -0400 Subject: [PATCH 70/76] update to 3.12 --- tests/test_internals_unification.py | 92 ++++++++++++++--------------- 1 file changed, 46 insertions(+), 46 deletions(-) diff --git a/tests/test_internals_unification.py b/tests/test_internals_unification.py index 538625ce..f554226e 100644 --- a/tests/test_internals_unification.py +++ b/tests/test_internals_unification.py @@ -359,35 +359,35 @@ def test_unify_tuple_non_variadic(): # Test functions with various type patterns -def identity(x: T) -> T: +def identity[T](x: T) -> T: return x -def make_pair(x: T, y: V) -> tuple[T, V]: +def make_pair[T, V](x: T, y: V) -> tuple[T, V]: return (x, y) -def wrap_in_list(x: T) -> list[T]: +def wrap_in_list[T](x: T) -> list[T]: return [x] -def get_first(items: list[T]) -> T: +def get_first[T](items: list[T]) -> T: return items[0] -def getitem_mapping(mapping: collections.abc.Mapping[K, V], key: K) -> V: +def getitem_mapping[K, V](mapping: collections.abc.Mapping[K, V], key: K) -> V: return mapping[key] -def dict_values(d: dict[K, V]) -> list[V]: +def dict_values[K, V](d: dict[K, V]) -> list[V]: return list(d.values()) -def process_callable(func: collections.abc.Callable[[T], V], arg: T) -> V: +def process_callable[T, V](func: collections.abc.Callable[[T], V], arg: T) -> V: return func(arg) -def chain_callables( +def chain_callables[T, U, V]( f: collections.abc.Callable[[T], U], g: collections.abc.Callable[[U], V] ) -> collections.abc.Callable[[T], V]: def result(x: T) -> V: @@ -400,23 +400,23 @@ def constant_func() -> int: return 42 -def multi_generic(a: T, b: list[T], c: dict[K, V]) -> tuple[T, K, V]: +def multi_generic[T, K, V](a: T, b: list[T], c: dict[K, V]) -> tuple[T, K, V]: return (a, next(iter(c.keys())), next(iter(c.values()))) -def same_type_twice(x: T, y: T) -> T: +def same_type_twice[T](x: T, y: T) -> T: return x if len(str(x)) > len(str(y)) else y -def nested_generic(x: T) -> dict[str, list[T]]: +def nested_generic[T](x: T) -> dict[str, list[T]]: return {"items": [x]} -def variadic_args_func(*args: T) -> T: # Variadic args not supported +def variadic_args_func[T](*args: T) -> T: # Variadic args not supported return args[0] -def variadic_kwargs_func(**kwargs: T) -> T: # Variadic kwargs not supported +def variadic_kwargs_func[T](**kwargs: T) -> T: # Variadic kwargs not supported return next(iter(kwargs.values())) @@ -512,15 +512,15 @@ def test_infer_return_type_success( # Error cases -def unbound_typevar_func(x: T) -> tuple[T, V]: # V not in parameters +def unbound_typevar_func[T](x: T) -> tuple[T, V]: # V not in parameters return (x, "error") -def no_return_annotation(x: T): # No return annotation +def no_return_annotation[T](x: T): # No return annotation return x -def no_param_annotation(x) -> T: # No parameter annotation +def no_param_annotation[T](x) -> T: # type: ignore return x @@ -668,35 +668,35 @@ def test_nested_type_term_error(): nested_type(mock_term) -def sequence_getitem(seq: collections.abc.Sequence[T], index: int) -> T: +def sequence_getitem[T](seq: collections.abc.Sequence[T], index: int) -> T: return seq[index] -def mapping_getitem(mapping: collections.abc.Mapping[K, V], key: K) -> V: +def mapping_getitem[K, V](mapping: collections.abc.Mapping[K, V], key: K) -> V: return mapping[key] -def sequence_mapping_getitem( +def sequence_mapping_getitem[K, V]( seq: collections.abc.Sequence[collections.abc.Mapping[K, V]], index: int, key: K ) -> V: return mapping_getitem(sequence_getitem(seq, index), key) -def mapping_sequence_getitem( +def mapping_sequence_getitem[K, T]( mapping: collections.abc.Mapping[K, collections.abc.Sequence[T]], key: K, index: int ) -> T: return sequence_getitem(mapping_getitem(mapping, key), index) -def sequence_from_pair(a: T, b: T) -> collections.abc.Sequence[T]: +def sequence_from_pair[T](a: T, b: T) -> collections.abc.Sequence[T]: return [a, b] -def mapping_from_pair(a: K, b: V) -> collections.abc.Mapping[K, V]: +def mapping_from_pair[K, V](a: K, b: V) -> collections.abc.Mapping[K, V]: return {a: b} -def sequence_of_mappings( +def sequence_of_mappings[K, V]( key1: K, val1: V, key2: K, val2: V ) -> collections.abc.Sequence[collections.abc.Mapping[K, V]]: """Creates a sequence containing two mappings.""" @@ -705,14 +705,14 @@ def sequence_of_mappings( ) -def mapping_of_sequences( +def mapping_of_sequences[K, T]( key1: K, val1: T, val2: T, key2: K, val3: T, val4: T ) -> collections.abc.Mapping[K, collections.abc.Sequence[T]]: """Creates a mapping where each key maps to a sequence of two values.""" return mapping_from_pair(key1, sequence_from_pair(val1, val2)) -def nested_sequence_mapping( +def nested_sequence_mapping[K, T]( k1: K, v1: T, v2: T, k2: K, v3: T, v4: T ) -> collections.abc.Sequence[collections.abc.Mapping[K, collections.abc.Sequence[T]]]: """Creates a sequence of mappings, where each mapping contains sequences.""" @@ -722,17 +722,17 @@ def nested_sequence_mapping( ) -def get_from_constructed_sequence(a: T, b: T, index: int) -> T: +def get_from_constructed_sequence[T](a: T, b: T, index: int) -> T: """Constructs a sequence from two elements and gets one by index.""" return sequence_getitem(sequence_from_pair(a, b), index) -def get_from_constructed_mapping(key: K, value: V, lookup_key: K) -> V: +def get_from_constructed_mapping[K, V](key: K, value: V, lookup_key: K) -> V: """Constructs a mapping from a key-value pair and looks up the value.""" return mapping_getitem(mapping_from_pair(key, value), lookup_key) -def double_nested_get( +def double_nested_get[K, T]( k1: K, v1: T, v2: T, @@ -750,7 +750,7 @@ def double_nested_get( return sequence_getitem(sequence, inner_index) -def construct_and_extend_sequence( +def construct_and_extend_sequence[T]( a: T, b: T, c: T, d: T ) -> collections.abc.Sequence[collections.abc.Sequence[T]]: """Constructs two sequences and combines them into a sequence of sequences.""" @@ -759,7 +759,7 @@ def construct_and_extend_sequence( return sequence_from_pair(seq1, seq2) -def transform_mapping_values( +def transform_mapping_values[K, T]( key1: K, val1: T, key2: K, val2: T ) -> collections.abc.Mapping[K, collections.abc.Sequence[T]]: """Creates a mapping where each value is wrapped in a sequence.""" @@ -769,7 +769,7 @@ def transform_mapping_values( return mapping_from_pair(key1, sequence_from_pair(val1, val1)) -def call_func( +def call_func[T, V]( func: collections.abc.Callable[[T], V], arg: T, ) -> V: @@ -777,7 +777,7 @@ def call_func( return func(arg) -def call_binary_func( +def call_binary_func[T, U, V]( func: collections.abc.Callable[[T, U], V], arg1: T, arg2: U, @@ -786,7 +786,7 @@ def call_binary_func( return func(arg1, arg2) -def map_sequence( +def map_sequence[T, U]( f: collections.abc.Callable[[T], U], seq: collections.abc.Sequence[T], ) -> collections.abc.Sequence[U]: @@ -794,7 +794,7 @@ def map_sequence( return [call_func(f, x) for x in seq] -def compose_mappings( +def compose_mappings[T, U, V]( f: collections.abc.Callable[[T], U], g: collections.abc.Callable[[U], V], ) -> collections.abc.Callable[[T], V]: @@ -806,7 +806,7 @@ def composed(x: T) -> V: return composed -def compose_binary( +def compose_binary[T, U, V]( f: collections.abc.Callable[[T], U], g: collections.abc.Callable[[U, U], V], ) -> collections.abc.Callable[[T], V]: @@ -818,7 +818,7 @@ def composed(x: T) -> V: return composed -def apply_to_sequence_element( +def apply_to_sequence_element[T, U]( f: collections.abc.Callable[[T], U], seq: collections.abc.Sequence[T], index: int, @@ -828,7 +828,7 @@ def apply_to_sequence_element( return call_func(f, element) -def map_and_get( +def map_and_get[T, U]( f: collections.abc.Callable[[T], U], seq: collections.abc.Sequence[T], index: int, @@ -838,7 +838,7 @@ def map_and_get( return sequence_getitem(mapped_seq, index) -def compose_and_apply( +def compose_and_apply[T, U, V]( f: collections.abc.Callable[[T], U], g: collections.abc.Callable[[U], V], value: T, @@ -848,7 +848,7 @@ def compose_and_apply( return call_func(composed, value) -def double_compose_apply( +def double_compose_apply[T, U, V, W]( f: collections.abc.Callable[[T], U], g: collections.abc.Callable[[U], V], h: collections.abc.Callable[[V], W], @@ -860,7 +860,7 @@ def double_compose_apply( return call_func(fgh, value) -def binary_on_sequence_elements( +def binary_on_sequence_elements[T, U]( f: collections.abc.Callable[[T, T], U], seq: collections.abc.Sequence[T], index1: int, @@ -872,7 +872,7 @@ def binary_on_sequence_elements( return call_binary_func(f, elem1, elem2) -def map_sequence_and_apply_binary( +def map_sequence_and_apply_binary[T, U, V]( f: collections.abc.Callable[[T], U], g: collections.abc.Callable[[U, U], V], seq: collections.abc.Sequence[T], @@ -886,7 +886,7 @@ def map_sequence_and_apply_binary( return call_binary_func(g, elem1, elem2) -def construct_apply_and_get( +def construct_apply_and_get[T, U]( f: collections.abc.Callable[[T], U], a: T, b: T, @@ -897,7 +897,7 @@ def construct_apply_and_get( return apply_to_sequence_element(f, seq, index) -def sequence_function_composition( +def sequence_function_composition[T]( funcs: collections.abc.Sequence[collections.abc.Callable[[T], T]], value: T, ) -> T: @@ -908,7 +908,7 @@ def sequence_function_composition( return result -def map_with_constructed_function( +def map_with_constructed_function[T, U, V]( f: collections.abc.Callable[[T], U], g: collections.abc.Callable[[U], V], seq: collections.abc.Sequence[T], @@ -918,7 +918,7 @@ def map_with_constructed_function( return map_sequence(composed, seq) -def cross_apply_binary( +def cross_apply_binary[T, U, V]( f: collections.abc.Callable[[T, U], V], seq1: collections.abc.Sequence[T], seq2: collections.abc.Sequence[U], @@ -931,7 +931,7 @@ def cross_apply_binary( return call_binary_func(f, elem1, elem2) -def nested_function_application( +def nested_function_application[T, U, V]( outer_f: collections.abc.Callable[[T], collections.abc.Callable[[U], V]], inner_arg: U, outer_arg: T, From 86c5431e97c6478b9e22fa5ce59b0fc5384f989b Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 11 Jul 2025 14:12:54 -0400 Subject: [PATCH 71/76] union --- effectful/internals/unification.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index eb674c61..144e1038 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -65,8 +65,10 @@ if typing.TYPE_CHECKING: GenericAlias = types.GenericAlias + UnionType = types.UnionType else: GenericAlias = types.GenericAlias | typing._GenericAlias + UnionType = types.UnionType | typing._UnionGenericAlias Substitutions = collections.abc.Mapping[ @@ -235,11 +237,11 @@ def _( @unify.register def _( - typ: types.UnionType, - subtyp: type | types.UnionType, + typ: UnionType, + subtyp: type | UnionType, subs: Substitutions = {}, ) -> Substitutions: - if isinstance(subtyp, types.UnionType): + if isinstance(subtyp, UnionType): # If subtyp is a union, try to unify with each argument for arg in typing.get_args(subtyp): subs = unify(typ, arg, subs) @@ -261,7 +263,7 @@ def _( @unify.register def _( typ: GenericAlias, - subtyp: type | types.GenericAlias | typing.TypeVar | types.UnionType, + subtyp: type | types.GenericAlias | typing.TypeVar | UnionType, subs: Substitutions = {}, ) -> Substitutions: if ( @@ -301,12 +303,12 @@ def _( @unify.register def _( typ: type, - subtyp: type | typing.TypeVar | types.UnionType | GenericAlias, + subtyp: type | typing.TypeVar | UnionType | GenericAlias, subs: Substitutions = {}, ) -> Substitutions: if isinstance(subtyp, typing.TypeVar): return unify(subtyp, subs.get(subtyp, typ), {subtyp: typ, **subs}) - elif isinstance(subtyp, types.UnionType): + elif isinstance(subtyp, UnionType): for arg in typing.get_args(subtyp): subs = unify(typ, arg, subs) return subs @@ -325,7 +327,7 @@ def _( @unify.register def _( typ: typing.TypeVar, - subtyp: type | typing.TypeVar | types.UnionType | types.GenericAlias, + subtyp: type | typing.TypeVar | UnionType | types.GenericAlias, subs: Substitutions = {}, ) -> Substitutions: return ( @@ -423,7 +425,7 @@ def _freshen(tp: typing.Any): @functools.singledispatch def nested_type( value, -) -> type | GenericAlias | types.UnionType | types.EllipsisType | None: +) -> type | GenericAlias | UnionType | types.EllipsisType | None: """ Infer the type of a value, handling nested collections with generic parameters. @@ -557,7 +559,7 @@ def nested_type( @nested_type.register -def _(value: type | types.UnionType | GenericAlias | types.EllipsisType): +def _(value: type | UnionType | GenericAlias | types.EllipsisType): return value @@ -744,7 +746,7 @@ def _(typ: GenericAlias): @freetypevars.register -def _(typ: types.UnionType): +def _(typ: UnionType): return freetypevars(typing.get_args(typ)) @@ -754,7 +756,7 @@ def substitute( ) -> ( type | types.GenericAlias - | types.UnionType + | UnionType | None | typing.TypeVar | typing.ParamSpec @@ -858,7 +860,7 @@ def _(typ: GenericAlias, subs: Substitutions): @substitute.register -def _(typ: types.UnionType, subs: Substitutions): +def _(typ: UnionType, subs: Substitutions): ts: tuple = substitute(typing.get_args(typ), subs) # type: ignore tp, ts = ts[0], ts[1:] for arg in ts: From 3811b83382125404851a41661d010edb9264bf88 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 11 Jul 2025 14:19:38 -0400 Subject: [PATCH 72/76] try truncating?? --- effectful/internals/unification.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index 144e1038..d4bb346f 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -407,14 +407,14 @@ def _freshen(tp: typing.Any): tp, { fv: typing.TypeVar( - name=f"{fv.__name__[:100]}_{random.randint(0, 1 << 32)}", + name=f"{fv.__name__[:60]}_{random.randint(0, int(1e7))}", bound=fv.__bound__, covariant=fv.__covariant__, contravariant=fv.__contravariant__, ) if isinstance(fv, typing.TypeVar) else typing.ParamSpec( - name=f"{fv.__name__[:100]}_{random.randint(0, 1 << 32)}" + name=f"{fv.__name__[:60]}_{random.randint(0, int(1e7))}", ) for fv in freetypevars(tp) if isinstance(fv, typing.TypeVar | typing.ParamSpec) From 4263617610a9d0a3c36e6e187cd3f466aaa8dbfc Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 11 Jul 2025 18:34:06 -0400 Subject: [PATCH 73/76] tweaks --- effectful/internals/unification.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/effectful/internals/unification.py b/effectful/internals/unification.py index d4bb346f..b709c6a3 100644 --- a/effectful/internals/unification.py +++ b/effectful/internals/unification.py @@ -580,9 +580,11 @@ def _(value: collections.abc.Callable): @nested_type.register def _(value: collections.abc.Mapping): - from effectful.ops.types import Interpretation + from effectful.ops.types import Interpretation, Term - if type(value) is Interpretation: # More specific check + if isinstance(value, Term): + raise TypeError(f"Terms should not appear in nested_type, but got {value}") + elif isinstance(value, Interpretation): # More specific check return Interpretation elif len(value) == 0: return type(value) @@ -735,7 +737,7 @@ def _(typ: str | bytes): @freetypevars.register -def _(typ: collections.abc.Mapping): +def _(typ: dict): assert all(isinstance(k, str) for k in typ.keys()), "Mapping keys must be strings" return freetypevars(typ.values()) @@ -844,7 +846,7 @@ def _(typ: tuple, subs: Substitutions): @substitute.register -def _(typ: collections.abc.Mapping, subs: Substitutions): +def _(typ: dict, subs: Substitutions): assert all(isinstance(k, str) for k in typ.keys()), "Mapping keys must be strings" return {k: substitute(v, subs) for k, v in typ.items()} From c8bb57ce0b8f75d0ef8b209efe8d17dbfe941148 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 11 Jul 2025 18:48:49 -0400 Subject: [PATCH 74/76] format, add set term --- effectful/ops/syntax.py | 33 ++++++++++++++------------------- tests/test_ops_syntax.py | 8 -------- 2 files changed, 14 insertions(+), 27 deletions(-) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 05013630..7901c40e 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -893,18 +893,6 @@ def defterm[T](__dispatch: Callable[[type], Callable[[T], Expr[T]]], value: T): return __dispatch(type(value))(value) -def _map_structure_and_keys(func, structure): - def _map_value(value): - if isinstance(value, dict): - return {func(k): v for k, v in value.items()} - elif not tree.is_nested(value): - return func(value) - else: - return value - - return tree.traverse(_map_value, structure, top_down=False) - - @_CustomSingleDispatchCallable def defdata[T]( __dispatch: Callable[[type], Callable[..., Expr[T]]], @@ -1121,7 +1109,9 @@ def defstream[S, T, A, B]( @defdata.register(collections.abc.Iterable) -class _IterableTerm[T](_BaseTerm[collections.abc.Iterable[T]], collections.abc.Iterable[T]): +class _IterableTerm[T]( + _BaseTerm[collections.abc.Iterable[T]], collections.abc.Iterable[T] +): @defop def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: if not isinstance(self, Term): @@ -1140,12 +1130,8 @@ def __next__(self: collections.abc.Iterator[T]) -> T: raise NotImplementedError -iter_ = _IterableTerm.__iter__ -next_ = _IteratorTerm.__next__ - - @defdata.register(collections.abc.Collection) -class _CollectionTerm[T](_IterableTerm[T]): +class _CollectionTerm[T](_IterableTerm[T], collections.abc.Collection[T]): @defop def __contains__(self: collections.abc.Collection[T], item: T) -> bool: if not isinstance(self, Term) and not isinstance(item, Term): @@ -1161,10 +1147,15 @@ def __len__(self: collections.abc.Collection[T]) -> int: raise NotImplementedError +@defdata.register(collections.abc.Set) +class _SetTerm[T](_CollectionTerm[T], collections.abc.Set[T]): + pass + + @defdata.register(collections.abc.Sequence) class _SequenceTerm[T](_CollectionTerm[T], collections.abc.Sequence[T]): @defop - def __getitem__(self: collections.abc.Sequence[T], index: int) -> T: + def __getitem__(self: collections.abc.Sequence[T], index: int | slice) -> T: if not isinstance(self, Term) and not isinstance(index, Term): return self[index] else: @@ -1181,6 +1172,10 @@ def __getitem__(self: collections.abc.Mapping[S, V], key: S) -> V: raise NotImplementedError +iter_ = _IterableTerm.__iter__ +next_ = _IteratorTerm.__next__ + + def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool: """Syntactic equality, ignoring the interpretation of the terms. diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index 37ee9c1b..b6a0dac0 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -11,7 +11,6 @@ from effectful.ops.syntax import ( Scoped, _CustomSingleDispatchCallable, - _map_structure_and_keys, deffn, defop, defstream, @@ -112,13 +111,6 @@ def f(x): assert f_op != ff_op -def test_map_structure_and_keys(): - s = {1: 2, 3: [4, 5, (6, {7: 8})]} - expected = {2: 3, 4: [5, 6, (7, {8: 9})]} - actual = _map_structure_and_keys(lambda x: x + 1, s) - assert actual == expected - - def test_scoped_collections(): """Test that Scoped annotations work with tree-structured collections containing Operations.""" From 7add44111c9394daaa5b2509bc90bfcc61aadc5c Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 11 Jul 2025 18:49:41 -0400 Subject: [PATCH 75/76] ops --- effectful/ops/syntax.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 7901c40e..4c442aec 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -1174,6 +1174,10 @@ def __getitem__(self: collections.abc.Mapping[S, V], key: S) -> V: iter_ = _IterableTerm.__iter__ next_ = _IteratorTerm.__next__ +len_ = _CollectionTerm.__len__ +contains = _CollectionTerm.__contains__ +sequence_getitem = _SequenceTerm.__getitem__ +mapping_getitem = _MappingTerm.__getitem__ def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool: From 507ee1f98e0c28eb4f6dd701aff21f2d4fe16247 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 12 Jul 2025 23:04:13 -0400 Subject: [PATCH 76/76] stash work on collections --- effectful/handlers/numbers.py | 272 ++----------- effectful/ops/syntax.py | 696 +++++++++++++++++++++++++++------ tests/test_handlers_numbers.py | 1 - tests/test_ops_syntax.py | 20 +- 4 files changed, 616 insertions(+), 373 deletions(-) diff --git a/effectful/handlers/numbers.py b/effectful/handlers/numbers.py index a1f5ccd7..f641f47e 100644 --- a/effectful/handlers/numbers.py +++ b/effectful/handlers/numbers.py @@ -3,251 +3,29 @@ """ import numbers -import operator -from typing import Any -from effectful.ops.syntax import defdata, defop, syntactic_eq -from effectful.ops.types import Expr, Operation, Term - - -@defdata.register(numbers.Number) -@numbers.Number.register -class _NumberTerm(Term[numbers.Number]): - def __init__( - self, op: Operation[..., numbers.Number], *args: Expr, **kwargs: Expr - ) -> None: - self._op = op - self._args = args - self._kwargs = kwargs - - @property - def op(self) -> Operation[..., numbers.Number]: - return self._op - - @property - def args(self) -> tuple: - return self._args - - @property - def kwargs(self) -> dict: - return self._kwargs - - def __hash__(self): - return hash((self.op, tuple(self.args), tuple(self.kwargs.items()))) - - -# Complex specific methods -@defop -def eq[T_Number: numbers.Number](x: T_Number, y: T_Number) -> bool: - if not any(isinstance(a, Term) for a in (x, y)): - return operator.eq(x, y) - else: - return syntactic_eq(x, y) - - -def _wrap_cmp(op): - def _wrapped_op[T_Number: numbers.Number](x: T_Number, y: T_Number) -> bool: - if not any(isinstance(a, Term) for a in (x, y)): - return op(x, y) - else: - raise NotImplementedError - - _wrapped_op.__name__ = op.__name__ - return _wrapped_op - - -def _wrap_binop(op): - def _wrapped_op[T_Number: numbers.Number](x: T_Number, y: T_Number) -> T_Number: - if not any(isinstance(a, Term) for a in (x, y)): - return op(x, y) - else: - raise NotImplementedError - - _wrapped_op.__name__ = op.__name__ - return _wrapped_op - - -def _wrap_unop(op): - def _wrapped_op[T_Number: numbers.Number](x: T_Number) -> T_Number: - if not isinstance(x, Term): - return op(x) - else: - raise NotImplementedError - - _wrapped_op.__name__ = op.__name__ - return _wrapped_op - - -add = defop(_wrap_binop(operator.add)) -neg = defop(_wrap_unop(operator.neg)) -pos = defop(_wrap_unop(operator.pos)) -sub = defop(_wrap_binop(operator.sub)) -mul = defop(_wrap_binop(operator.mul)) -truediv = defop(_wrap_binop(operator.truediv)) -pow = defop(_wrap_binop(operator.pow)) -abs = defop(_wrap_unop(operator.abs)) - - -@defdata.register(numbers.Complex) -@numbers.Complex.register -class _ComplexTerm(_NumberTerm, Term[numbers.Complex]): - def __bool__(self) -> bool: - raise ValueError("Cannot convert term to bool") - - def __add__(self, other: Any) -> numbers.Real: - return add(self, other) - - def __radd__(self, other: Any) -> numbers.Real: - return add(other, self) - - def __neg__(self): - return neg(self) - - def __pos__(self): - return pos(self) - - def __sub__(self, other: Any) -> numbers.Real: - return sub(self, other) - - def __rsub__(self, other: Any) -> numbers.Real: - return sub(other, self) - - def __mul__(self, other: Any) -> numbers.Real: - return mul(self, other) - - def __rmul__(self, other: Any) -> numbers.Real: - return mul(other, self) - - def __truediv__(self, other: Any) -> numbers.Real: - return truediv(self, other) - - def __rtruediv__(self, other: Any) -> numbers.Real: - return truediv(other, self) - - def __pow__(self, other: Any) -> numbers.Real: - return pow(self, other) - - def __rpow__(self, other: Any) -> numbers.Real: - return pow(other, self) - - def __abs__(self) -> numbers.Real: - return abs(self) - - def __eq__(self, other: Any) -> bool: - return eq(self, other) - - -# Real specific methods -floordiv = defop(_wrap_binop(operator.floordiv)) -mod = defop(_wrap_binop(operator.mod)) -lt = defop(_wrap_cmp(operator.lt)) -le = defop(_wrap_cmp(operator.le)) -gt = defop(_wrap_cmp(operator.gt)) -ge = defop(_wrap_cmp(operator.ge)) - - -@defdata.register(numbers.Real) -@numbers.Real.register -class _RealTerm(_ComplexTerm, Term[numbers.Real]): - # Real specific methods - def __float__(self) -> float: - raise ValueError("Cannot convert term to float") - - def __trunc__(self) -> numbers.Integral: - raise NotImplementedError - - def __floor__(self) -> numbers.Integral: - raise NotImplementedError - - def __ceil__(self) -> numbers.Integral: - raise NotImplementedError - - def __round__(self, ndigits=None) -> numbers.Integral: - raise NotImplementedError - - def __floordiv__(self, other): - return floordiv(self, other) - - def __rfloordiv__(self, other): - return floordiv(other, self) - - def __mod__(self, other): - return mod(self, other) - - def __rmod__(self, other): - return mod(other, self) - - def __lt__(self, other): - return lt(self, other) - - def __le__(self, other): - return le(self, other) - - -@defdata.register(numbers.Rational) -@numbers.Rational.register -class _RationalTerm(_RealTerm, Term[numbers.Rational]): - @property - def numerator(self): - raise NotImplementedError - - @property - def denominator(self): - raise NotImplementedError - - -# Integral specific methods -index = defop(_wrap_unop(operator.index)) -lshift = defop(_wrap_binop(operator.lshift)) -rshift = defop(_wrap_binop(operator.rshift)) -and_ = defop(_wrap_binop(operator.and_)) -xor = defop(_wrap_binop(operator.xor)) -or_ = defop(_wrap_binop(operator.or_)) -invert = defop(_wrap_unop(operator.invert)) - - -@defdata.register(numbers.Integral) -@numbers.Integral.register -class _IntegralTerm(_RationalTerm, Term[numbers.Integral]): - # Integral specific methods - def __int__(self) -> int: - raise ValueError("Cannot convert term to int") - - def __index__(self) -> numbers.Integral: - return index(self) - - def __pow__(self, exponent: Any, modulus=None) -> numbers.Integral: - return pow(self, exponent) - - def __lshift__(self, other): - return lshift(self, other) - - def __rlshift__(self, other): - return lshift(other, self) - - def __rshift__(self, other): - return rshift(self, other) - - def __rrshift__(self, other): - return rshift(other, self) - - def __and__(self, other): - return and_(self, other) - - def __rand__(self, other): - return and_(other, self) - - def __xor__(self, other): - return xor(self, other) - - def __rxor__(self, other): - return xor(other, self) - - def __or__(self, other): - return or_(self, other) - - def __ror__(self, other): - return or_(other, self) - - def __invert__(self): - return invert(self) +from effectful.ops.syntax import defdata + + +add = defdata.dispatch(numbers.Number).__add__ +neg = defdata.dispatch(numbers.Number).__neg__ +pos = defdata.dispatch(numbers.Number).__pos__ +sub = defdata.dispatch(numbers.Number).__sub__ +mul = defdata.dispatch(numbers.Number).__mul__ +truediv = defdata.dispatch(numbers.Number).__truediv__ +pow = defdata.dispatch(numbers.Number).__pow__ +abs = defdata.dispatch(numbers.Number).__abs__ +floordiv = defdata.dispatch(numbers.Number).__floordiv__ +mod = defdata.dispatch(numbers.Number).__mod__ +eq = defdata.dispatch(numbers.Number).__eq__ +lt = defdata.dispatch(numbers.Number).__lt__ +le = defdata.dispatch(numbers.Number).__le__ +gt = defdata.dispatch(numbers.Number).__gt__ +ge = defdata.dispatch(numbers.Number).__ge__ +index = defdata.dispatch(numbers.Number).__index__ +lshift = defdata.dispatch(numbers.Number).__lshift__ +rshift = defdata.dispatch(numbers.Number).__rshift__ +and_ = defdata.dispatch(numbers.Number).__and__ +xor = defdata.dispatch(numbers.Number).__xor__ +or_ = defdata.dispatch(numbers.Number).__or__ +invert = defdata.dispatch(numbers.Number).__invert__ diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index ac1d7fd6..1f3711b7 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -2,6 +2,8 @@ import dataclasses import functools import inspect +import numbers +import operator import random import types import typing @@ -981,48 +983,6 @@ def _[T](value: T) -> T: return value -@defdata.register(object) -class _BaseTerm[T](Term[T]): - _op: Operation[..., T] - _args: collections.abc.Sequence[Expr] - _kwargs: collections.abc.Mapping[str, Expr] - - def __init__( - self, - op: Operation[..., T], - *args: Expr, - **kwargs: Expr, - ): - self._op = op - self._args = args - self._kwargs = kwargs - - def __eq__(self, other) -> bool: - from effectful.ops.syntax import syntactic_eq - - return syntactic_eq(self, other) - - @property - def op(self): - return self._op - - @property - def args(self): - return self._args - - @property - def kwargs(self): - return self._kwargs - - -@defdata.register(collections.abc.Callable) -class _CallableTerm[**P, T](_BaseTerm[collections.abc.Callable[P, T]]): - def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]: - from effectful.ops.semantics import call - - return call(self, *args, **kwargs) # type: ignore - - def trace[**P, T](value: Callable[P, T]) -> Callable[P, T]: """Convert a callable to a term by calling it with appropriately typed free variables. @@ -1087,78 +1047,6 @@ def defstream[S, T, A, B]( raise NotImplementedError -@defdata.register(collections.abc.Iterable) -class _IterableTerm[T]( - _BaseTerm[collections.abc.Iterable[T]], collections.abc.Iterable[T] -): - @defop - def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: - if not isinstance(self, Term): - return iter(self) - else: - raise NotImplementedError - - -@defdata.register(collections.abc.Iterator) -class _IteratorTerm[T](_IterableTerm[T], collections.abc.Iterator[T]): - @defop - def __next__(self: collections.abc.Iterator[T]) -> T: - if not isinstance(self, Term): - return next(self) - else: - raise NotImplementedError - - -@defdata.register(collections.abc.Collection) -class _CollectionTerm[T](_IterableTerm[T], collections.abc.Collection[T]): - @defop - def __contains__(self: collections.abc.Collection[T], item: T) -> bool: - if not isinstance(self, Term) and not isinstance(item, Term): - return item in self - else: - raise NotImplementedError - - @defop - def __len__(self: collections.abc.Collection[T]) -> int: - if not isinstance(self, Term): - return len(self) - else: - raise NotImplementedError - - -@defdata.register(collections.abc.Set) -class _SetTerm[T](_CollectionTerm[T], collections.abc.Set[T]): - pass - - -@defdata.register(collections.abc.Sequence) -class _SequenceTerm[T](_CollectionTerm[T], collections.abc.Sequence[T]): - @defop - def __getitem__(self: collections.abc.Sequence[T], index: int | slice) -> T: - if not isinstance(self, Term) and not isinstance(index, Term): - return self[index] - else: - raise NotImplementedError - - -@defdata.register(collections.abc.Mapping) -class _MappingTerm[S, V](_CollectionTerm[S], collections.abc.Mapping[S, V]): - @defop - def __getitem__(self: collections.abc.Mapping[S, V], key: S) -> V: - if not isinstance(self, Term) and not isinstance(key, Term): - return self[key] - else: - raise NotImplementedError - - -iter_ = _IterableTerm.__iter__ -next_ = _IteratorTerm.__next__ -len_ = _CollectionTerm.__len__ -contains = _CollectionTerm.__contains__ -sequence_getitem = _SequenceTerm.__getitem__ -mapping_getitem = _MappingTerm.__getitem__ - - def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool: """Syntactic equality, ignoring the interpretation of the terms. @@ -1308,3 +1196,583 @@ def implements[**P, V](op: Operation[P, V]): """ return _ImplementedOperation(op) + + +@defdata.register(object) +class _BaseTerm[T](Term[T]): + _op: Operation[..., T] + _args: collections.abc.Sequence[Expr] + _kwargs: collections.abc.Mapping[str, Expr] + + def __init__( + self, + op: Operation[..., T], + *args: Expr, + **kwargs: Expr, + ): + self._op = op + self._args = args + self._kwargs = kwargs + + def __eq__(self, other) -> bool: + from effectful.ops.syntax import syntactic_eq + + return syntactic_eq(self, other) + + @property + def op(self): + return self._op + + @property + def args(self): + return self._args + + @property + def kwargs(self): + return self._kwargs + + +@defdata.register(collections.abc.Callable) +class _CallableTerm[**P, T](_BaseTerm[collections.abc.Callable[P, T]]): + def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]: + from effectful.ops.semantics import call + + return call(self, *args, **kwargs) # type: ignore + + +@defdata.register(collections.abc.Iterable) +class _IterableTerm[T]( + _BaseTerm[collections.abc.Iterable[T]], collections.abc.Iterable[T] +): + @defop + def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: + if not isinstance(self, Term): + return self.__iter__() + else: + raise NotImplementedError + + +@defdata.register(collections.abc.Iterator) +class _IteratorTerm[T](_IterableTerm[T], collections.abc.Iterator[T]): + @defop + def __next__(self: collections.abc.Iterator[T]) -> T: + if not isinstance(self, Term): + return self.__next__() + else: + raise NotImplementedError + + +# @defdata.register(collections.abc.Reversible) +class _ReversibleTerm[T](_IterableTerm[T], collections.abc.Reversible[T]): + @defop + def __reversed__(self: collections.abc.Reversible[T]) -> collections.abc.Iterator[T]: + if not isinstance(self, Term): + return self.__reversed__() + else: + raise NotImplementedError + + +@defdata.register(collections.abc.Collection) +class _CollectionTerm[T](_IterableTerm[T], collections.abc.Collection[T]): + @defop + def __len__(self: collections.abc.Sized) -> int: + if not isinstance(self, Term): + return self.__len__() + else: + raise NotImplementedError + + @defop + def __contains__(self: collections.abc.Container[T], item: T) -> bool: + if not isinstance(self, Term) and not isinstance(item, Term): + return self.__contains__(item) + else: + raise NotImplementedError + + +@defdata.register(collections.abc.Sequence) +class _SequenceTerm[T](_CollectionTerm[T], _ReversibleTerm[T], collections.abc.Sequence[T]): + @defop + def __getitem__(self: collections.abc.Sequence[T], index: int | slice) -> T: + if not isinstance(self, Term) and not isinstance(index, Term): + return self.__getitem__(index) + else: + raise NotImplementedError + + @defop + def index( + self: collections.abc.Sequence[T], value: T, start: int = 0, stop: int | None = None + ) -> int: + if not isinstance(self, Term) and not isinstance(value, Term): + return self.index(value, start, stop) + else: + raise NotImplementedError + + @defop + def count(self: collections.abc.Sequence[T], value: T) -> int: + if not isinstance(self, Term) and not isinstance(value, Term): + return self.count(value) + else: + raise NotImplementedError + + +@defdata.register(collections.abc.Set) +@functools.total_ordering +class _SetTerm[T](_CollectionTerm[T], collections.abc.Set[T]): + @defop + def __eq__(self: collections.abc.Set[T], other: collections.abc.Set[T]) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__eq__(other) + else: + raise NotImplementedError + + @defop + def __lt__(self: collections.abc.Set[T], other: collections.abc.Set[T]) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__lt__(other) + else: + raise NotImplementedError + + @defop + def __and__[V](self: collections.abc.Set[T], other: collections.abc.Set[V]) -> collections.abc.Set[T | V]: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__and__(other) + else: + raise NotImplementedError + + @defop + def __or__[V](self: collections.abc.Set[T], other: collections.abc.Set[V]) -> collections.abc.Set[T | V]: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__or__(other) + else: + raise NotImplementedError + + @defop + def __xor__[V](self: collections.abc.Set[T], other: collections.abc.Set[V]) -> collections.abc.Set[T | V]: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__xor__(other) + else: + raise NotImplementedError + + def __rxor__[V](self: collections.abc.Set[T], other: collections.abc.Set[V]) -> collections.abc.Set[T | V]: + """Right-hand side XOR operator for sets.""" + return type(self).__xor__(other, self) + + @defop + def __sub__[V](self: collections.abc.Set[T], other: collections.abc.Set[V]) -> collections.abc.Set[T]: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__sub__(other) + else: + raise NotImplementedError + + def __rsub__[V](self: collections.abc.Set[T], other: collections.abc.Set[V]) -> collections.abc.Set[T]: + """Right-hand side subtraction operator for sets.""" + return type(self).__sub__(other, self) + + @defop + def isdisjoint[V](self: collections.abc.Set[T], other: collections.abc.Iterable[V]) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.isdisjoint(other) + else: + raise NotImplementedError + + +@defdata.register(collections.abc.Mapping) +class _MappingTerm[S, V](_CollectionTerm[S], collections.abc.Mapping[S, V]): + @defop + def __getitem__(self: collections.abc.Mapping[S, V], key: S) -> V: + if not isinstance(self, Term) and not isinstance(key, Term): + return self.__getitem__(key) + else: + raise NotImplementedError + + @defop + def __eq__(self: collections.abc.Mapping[S, V], other: collections.abc.Mapping[S, V]) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__eq__(other) + else: + raise NotImplementedError + + @defop + def keys(self: collections.abc.Mapping[S, V]) -> collections.abc.Set[S]: + if not isinstance(self, Term): + return self.keys() + else: + raise NotImplementedError + + @defop + def values(self: collections.abc.Mapping[S, V]) -> collections.abc.Collection[V]: + if not isinstance(self, Term): + return self.values() + else: + raise NotImplementedError + + @defop + def items(self: collections.abc.Mapping[S, V]) -> collections.abc.Set[tuple[S, V]]: + if not isinstance(self, Term): + return self.items() + else: + raise NotImplementedError + + +@defdata.register(numbers.Number) +@functools.total_ordering +class _NumberTerm[T: numbers.Number](_BaseTerm[T], numbers.Number): + def __hash__(self): + return id(self) + + def __complex__(self) -> complex: + raise ValueError("Cannot convert term to complex number") + + def __float__(self) -> float: + raise ValueError("Cannot convert term to float") + + def __int__(self) -> int: + raise ValueError("Cannot convert term to int") + + def __bool__(self) -> bool: + raise ValueError("Cannot convert term to bool") + + @defop + @property + def real(self: numbers.Complex) -> float: + if not isinstance(self, Term): + return self.real + else: + raise NotImplementedError + + @defop + @property + def imag(self: numbers.Complex) -> float: + if not isinstance(self, Term): + return self.imag + else: + raise NotImplementedError + + @defop + def conjugate(self: T) -> T: + if not isinstance(self, Term): + return self.conjugate() + else: + raise NotImplementedError + + @defop + @property + def numerator(self: numbers.Rational) -> int: + if not isinstance(self, Term): + return self.numerator + else: + raise NotImplementedError + + @defop + @property + def denominator(self: numbers.Rational) -> int: + if not isinstance(self, Term): + return self.denominator + else: + raise NotImplementedError + + @defop + def __abs__(self: numbers.Complex) -> float: + """Return the absolute value of the term.""" + if not isinstance(self, Term): + return self.__abs__() + else: + raise NotImplementedError + + @defop + def __neg__(self: T) -> T: + if not isinstance(self, Term): + return self.__neg__() + else: + raise NotImplementedError + + @defop + def __pos__(self: T) -> T: + if not isinstance(self, Term): + return self.__pos__() + else: + raise NotImplementedError + + @defop + def __trunc__(self: numbers.Real) -> int: + if not isinstance(self, Term): + return self.__trunc__() + else: + raise NotImplementedError + + @defop + def __floor__(self: numbers.Real) -> int: + if not isinstance(self, Term): + return self.__floor__() + else: + raise NotImplementedError + + @defop + def __ceil__(self: numbers.Real) -> int: + if not isinstance(self, Term): + return self.__ceil__() + else: + raise NotImplementedError + + @defop + def __round__(self: numbers.Real, ndigits: int | None = None) -> int | float: + if not isinstance(self, Term) and not isinstance(ndigits, Term): + return self.__round__(ndigits) + else: + raise NotImplementedError + + @defop + def __invert__(self: numbers.Integral) -> int: + if not isinstance(self, Term): + return self.__invert__() + else: + raise NotImplementedError + + @defop + def __index__(self: numbers.Integral) -> int: + if not isinstance(self, Term): + return self.__index__() + else: + raise NotImplementedError + + @defop + def __eq__(self: numbers.Complex, other: numbers.Complex) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__eq__(other) + else: + return syntactic_eq(self, other) + + @defop + def __lt__(self: numbers.Real, other: numbers.Real) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__lt__(other) + else: + raise NotImplementedError + + @defop + def __add__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__add__(self, other) + else: + raise NotImplementedError + + def __radd__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__add__(self) + elif not isinstance(other, Term): + return type(self).__add__(other, self) + else: + return NotImplemented + + @defop + def __sub__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__sub__(self, other) + else: + raise NotImplementedError + + def __rsub__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__sub__(self) + elif not isinstance(other, Term): + return type(self).__sub__(other, self) + else: + return NotImplemented + + @defop + def __mul__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__mul__(self, other) + else: + raise NotImplementedError + + def __rmul__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__mul__(self) + elif not isinstance(other, Term): + return type(self).__mul__(other, self) + else: + return NotImplemented + + @defop + def __truediv__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__truediv__(self, other) + else: + raise NotImplementedError + + def __rtruediv__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__truediv__(self) + elif not isinstance(other, Term): + return type(self).__truediv__(other, self) + else: + return NotImplemented + + @defop + def __floordiv__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__floordiv__(self, other) + else: + raise NotImplementedError + + def __rfloordiv__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__floordiv__(self) + elif not isinstance(other, Term): + return type(self).__floordiv__(other, self) + else: + return NotImplemented + + @defop + def __mod__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__mod__(self, other) + else: + raise NotImplementedError + + def __rmod__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__mod__(self) + elif not isinstance(other, Term): + return type(self).__mod__(other, self) + else: + return NotImplemented + + @defop + def __pow__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__pow__(self, other) + else: + raise NotImplementedError + + def __rpow__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__pow__(self) + elif not isinstance(other, Term): + return type(self).__pow__(other, self) + else: + return NotImplemented + + @defop + def __lshift__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__lshift__(self, other) + else: + raise NotImplementedError + + def __rlshift__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__lshift__(self) + elif not isinstance(other, Term): + return type(self).__lshift__(other, self) + else: + return NotImplemented + + @defop + def __rshift__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__rshift__(self, other) + else: + raise NotImplementedError + + def __rrshift__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__rshift__(self) + elif not isinstance(other, Term): + return type(self).__rshift__(other, self) + else: + return NotImplemented + + @defop + def __and__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__and__(self, other) + else: + raise NotImplementedError + + def __rand__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__and__(self) + elif not isinstance(other, Term): + return type(self).__and__(other, self) + else: + return NotImplemented + + @defop + def __xor__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__xor__(self, other) + else: + raise NotImplementedError + + def __rxor__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__xor__(self) + elif not isinstance(other, Term): + return type(self).__xor__(other, self) + else: + return NotImplemented + + @defop + def __or__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__or__(self, other) + else: + raise NotImplementedError + + def __ror__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__or__(self) + elif not isinstance(other, Term): + return type(self).__or__(other, self) + else: + return NotImplemented + + +# # numbers.Complex polymorphic binary methods +# __add__, __radd__ = _wrapper_generator(operator.__add__) +# __sub__, __rsub__ = _wrapper_generator(operator.__sub__) +# __mul__, __rmul__ = _wrapper_generator(operator.__mul__) +# __truediv__, __rtruediv__ = _wrapper_generator(operator.__truediv__) +# __pow__, __rpow__ = _wrapper_generator(operator.__pow__) + +# # numbers.Real polymorphic binary methods +# __floordiv__, __rfloordiv__ = _wrapper_generator(operator.__floordiv__) +# __mod__, __rmod__ = _wrapper_generator(operator.__mod__) + +# # numbers.Integral polymorphic binary methods +# __lshift__, __rlshift__ = _wrapper_generator(operator.__lshift__) +# __rshift__, __rrshift__ = _wrapper_generator(operator.__rshift__) +# __and__, __rand__ = _wrapper_generator(operator.__and__) +# __xor__, __rxor__ = _wrapper_generator(operator.__xor__) +# __or__, __ror__ = _wrapper_generator(operator.__or__) + + +@defdata.register(numbers.Complex) +@numbers.Complex.register +class _ComplexTerm[T: numbers.Complex](_NumberTerm[T]): + pass + + +@defdata.register(numbers.Real) +@numbers.Real.register +class _RealTerm[T: numbers.Real](_ComplexTerm[T]): + pass + + +@defdata.register(numbers.Rational) +@numbers.Rational.register +class _RationalTerm[T: numbers.Rational](_RealTerm[T]): + pass + + +@defdata.register(numbers.Integral) +@numbers.Integral.register +class _IntegralTerm[T: numbers.Integral](_RationalTerm[T]): + pass + + +@defdata.register(bool) +class _BoolTerm[T: bool](_IntegralTerm[T]): + pass \ No newline at end of file diff --git a/tests/test_handlers_numbers.py b/tests/test_handlers_numbers.py index 7635d548..5d571d38 100644 --- a/tests/test_handlers_numbers.py +++ b/tests/test_handlers_numbers.py @@ -248,7 +248,6 @@ def test_evaluate_2(): y = defop(int, name="y") t = x() + y() assert isinstance(t, Term) - assert t.op.__name__ == "add" with handler({x: lambda: 1, y: lambda: 3}): assert evaluate(t) == 4 diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index b6a0dac0..d35b5bcc 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -15,8 +15,6 @@ defop, defstream, defterm, - iter_, - next_, ) from effectful.ops.types import Operation, Term @@ -198,11 +196,11 @@ def test_term_str(): assert str(x1) == str(x2) == str(x3) == "x" assert repr(x1) != repr(x2) != repr(x3) - assert str(x1() + x2()) == "add(x(), x!1())" - assert str(x1() + x1()) == "add(x(), x())" - assert str(deffn(x1() + x1(), x1)) == "deffn(add(x(), x()), x)" - assert str(deffn(x1() + x1(), x2)) == "deffn(add(x(), x()), x!1)" - assert str(deffn(x1() + x2(), x1)) == "deffn(add(x(), x!1()), x)" + assert str(x1() + x2()) == "__add__(x(), x!1())" + assert str(x1() + x1()) == "__add__(x(), x())" + assert str(deffn(x1() + x1(), x1)) == "deffn(__add__(x(), x()), x)" + assert str(deffn(x1() + x1(), x2)) == "deffn(__add__(x(), x()), x!1)" + assert str(deffn(x1() + x2(), x1)) == "deffn(__add__(x(), x!1()), x)" def test_defop_singledispatch(): @@ -507,13 +505,13 @@ def cons_iterable(*args: int) -> Iterable[int]: assert isinstance(tm_iter, Term) assert isinstance(tm_iter, Iterator) assert issubclass(typeof(tm_iter), Iterator) - assert tm_iter.op is iter_ + assert tm_iter.op is type(tm).__iter__ tm_iter_next = next(tm_iter) assert isinstance(tm_iter_next, Term) # assert isinstance(tm_iter_next, numbers.Number) # TODO # assert issubclass(typeof(tm_iter_next), numbers.Number) - assert tm_iter_next.op is next_ + assert tm_iter_next.op is type(tm_iter).__next__ assert list(tm.args) == [1, 2, 3] @@ -535,13 +533,13 @@ def test_defstream_1(): assert isinstance(tm_iter, Term) assert isinstance(tm_iter, Iterator) assert issubclass(typeof(tm_iter), Iterator) - assert tm_iter.op is iter_ + assert tm_iter.op is type(tm).__iter__ tm_iter_next = next(tm_iter) assert isinstance(tm_iter_next, Term) # assert isinstance(tm_iter_next, numbers.Number) # TODO # assert issubclass(typeof(tm_iter_next), numbers.Number) - assert tm_iter_next.op is next_ + assert tm_iter_next.op is type(tm_iter).__next__ def test_defterm_sequence():