diff --git a/effectful/handlers/numbers.py b/effectful/handlers/numbers.py index 57a0a0fb..874a3156 100644 --- a/effectful/handlers/numbers.py +++ b/effectful/handlers/numbers.py @@ -3,244 +3,28 @@ """ import numbers -import operator -from typing import Any -from effectful.ops.syntax import defdata, defop -from effectful.ops.types import Expr, Operation, Term - - -@defdata.register(numbers.Number) -@numbers.Number.register -class _NumberTerm(Term[numbers.Number]): - def __init__( - self, op: Operation[..., numbers.Number], *args: Expr, **kwargs: Expr - ) -> None: - self._op = op - self._args = args - self._kwargs = kwargs - - @property - def op(self) -> Operation[..., numbers.Number]: - return self._op - - @property - def args(self) -> tuple: - return self._args - - @property - def kwargs(self) -> dict: - return self._kwargs - - def __hash__(self): - return hash((self.op, tuple(self.args), tuple(self.kwargs.items()))) - - -# Complex specific methods -def _wrap_cmp(op): - def _wrapped_op[T_Number: numbers.Number](x: T_Number, y: T_Number) -> bool: - if not any(isinstance(a, Term) for a in (x, y)): - return op(x, y) - else: - raise NotImplementedError - - _wrapped_op.__name__ = op.__name__ - return _wrapped_op - - -def _wrap_binop(op): - def _wrapped_op[T_Number: numbers.Number](x: T_Number, y: T_Number) -> T_Number: - if not any(isinstance(a, Term) for a in (x, y)): - return op(x, y) - else: - raise NotImplementedError - - _wrapped_op.__name__ = op.__name__ - return _wrapped_op - - -def _wrap_unop(op): - def _wrapped_op[T_Number: numbers.Number](x: T_Number) -> T_Number: - if not isinstance(x, Term): - return op(x) - else: - raise NotImplementedError - - _wrapped_op.__name__ = op.__name__ - return _wrapped_op - - -add = defop(_wrap_binop(operator.add)) -neg = defop(_wrap_unop(operator.neg)) -pos = defop(_wrap_unop(operator.pos)) -sub = defop(_wrap_binop(operator.sub)) -mul = defop(_wrap_binop(operator.mul)) -truediv = defop(_wrap_binop(operator.truediv)) -pow = defop(_wrap_binop(operator.pow)) -abs = defop(_wrap_unop(operator.abs)) -eq = defop(_wrap_cmp(operator.eq)) - - -@defdata.register(numbers.Complex) -@numbers.Complex.register -class _ComplexTerm(_NumberTerm, Term[numbers.Complex]): - def __bool__(self) -> bool: - raise ValueError("Cannot convert term to bool") - - def __add__(self, other: Any) -> numbers.Real: - return add(self, other) - - def __radd__(self, other: Any) -> numbers.Real: - return add(other, self) - - def __neg__(self): - return neg(self) - - def __pos__(self): - return pos(self) - - def __sub__(self, other: Any) -> numbers.Real: - return sub(self, other) - - def __rsub__(self, other: Any) -> numbers.Real: - return sub(other, self) - - def __mul__(self, other: Any) -> numbers.Real: - return mul(self, other) - - def __rmul__(self, other: Any) -> numbers.Real: - return mul(other, self) - - def __truediv__(self, other: Any) -> numbers.Real: - return truediv(self, other) - - def __rtruediv__(self, other: Any) -> numbers.Real: - return truediv(other, self) - - def __pow__(self, other: Any) -> numbers.Real: - return pow(self, other) - - def __rpow__(self, other: Any) -> numbers.Real: - return pow(other, self) - - def __abs__(self) -> numbers.Real: - return abs(self) - - def __eq__(self, other: Any) -> bool: - return eq(self, other) - - -# Real specific methods -floordiv = defop(_wrap_binop(operator.floordiv)) -mod = defop(_wrap_binop(operator.mod)) -lt = defop(_wrap_cmp(operator.lt)) -le = defop(_wrap_cmp(operator.le)) -gt = defop(_wrap_cmp(operator.gt)) -ge = defop(_wrap_cmp(operator.ge)) - - -@defdata.register(numbers.Real) -@numbers.Real.register -class _RealTerm(_ComplexTerm, Term[numbers.Real]): - # Real specific methods - def __float__(self) -> float: - raise ValueError("Cannot convert term to float") - - def __trunc__(self) -> numbers.Integral: - raise NotImplementedError - - def __floor__(self) -> numbers.Integral: - raise NotImplementedError - - def __ceil__(self) -> numbers.Integral: - raise NotImplementedError - - def __round__(self, ndigits=None) -> numbers.Integral: - raise NotImplementedError - - def __floordiv__(self, other): - return floordiv(self, other) - - def __rfloordiv__(self, other): - return floordiv(other, self) - - def __mod__(self, other): - return mod(self, other) - - def __rmod__(self, other): - return mod(other, self) - - def __lt__(self, other): - return lt(self, other) - - def __le__(self, other): - return le(self, other) - - -@defdata.register(numbers.Rational) -@numbers.Rational.register -class _RationalTerm(_RealTerm, Term[numbers.Rational]): - @property - def numerator(self): - raise NotImplementedError - - @property - def denominator(self): - raise NotImplementedError - - -# Integral specific methods -index = defop(_wrap_unop(operator.index)) -lshift = defop(_wrap_binop(operator.lshift)) -rshift = defop(_wrap_binop(operator.rshift)) -and_ = defop(_wrap_binop(operator.and_)) -xor = defop(_wrap_binop(operator.xor)) -or_ = defop(_wrap_binop(operator.or_)) -invert = defop(_wrap_unop(operator.invert)) - - -@defdata.register(numbers.Integral) -@numbers.Integral.register -class _IntegralTerm(_RationalTerm, Term[numbers.Integral]): - # Integral specific methods - def __int__(self) -> int: - raise ValueError("Cannot convert term to int") - - def __index__(self) -> numbers.Integral: - return index(self) - - def __pow__(self, exponent: Any, modulus=None) -> numbers.Integral: - return pow(self, exponent) - - def __lshift__(self, other): - return lshift(self, other) - - def __rlshift__(self, other): - return lshift(other, self) - - def __rshift__(self, other): - return rshift(self, other) - - def __rrshift__(self, other): - return rshift(other, self) - - def __and__(self, other): - return and_(self, other) - - def __rand__(self, other): - return and_(other, self) - - def __xor__(self, other): - return xor(self, other) - - def __rxor__(self, other): - return xor(other, self) - - def __or__(self, other): - return or_(self, other) - - def __ror__(self, other): - return or_(other, self) - - def __invert__(self): - return invert(self) +from effectful.ops.syntax import defdata + +add = defdata.dispatch(numbers.Number).__add__ +neg = defdata.dispatch(numbers.Number).__neg__ +pos = defdata.dispatch(numbers.Number).__pos__ +sub = defdata.dispatch(numbers.Number).__sub__ +mul = defdata.dispatch(numbers.Number).__mul__ +truediv = defdata.dispatch(numbers.Number).__truediv__ +pow = defdata.dispatch(numbers.Number).__pow__ +abs = defdata.dispatch(numbers.Number).__abs__ +floordiv = defdata.dispatch(numbers.Number).__floordiv__ +mod = defdata.dispatch(numbers.Number).__mod__ +eq = defdata.dispatch(numbers.Number).__eq__ +lt = defdata.dispatch(numbers.Number).__lt__ +le = defdata.dispatch(numbers.Number).__le__ +gt = defdata.dispatch(numbers.Number).__gt__ +ge = defdata.dispatch(numbers.Number).__ge__ +index = defdata.dispatch(numbers.Number).__index__ +lshift = defdata.dispatch(numbers.Number).__lshift__ +rshift = defdata.dispatch(numbers.Number).__rshift__ +and_ = defdata.dispatch(numbers.Number).__and__ +xor = defdata.dispatch(numbers.Number).__xor__ +or_ = defdata.dispatch(numbers.Number).__or__ +invert = defdata.dispatch(numbers.Number).__invert__ diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 9fea27d6..8bd81675 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -1,3 +1,4 @@ +import collections.abc import contextlib import functools import types @@ -5,8 +6,6 @@ from collections.abc import Callable from typing import Any -import tree - from effectful.ops.syntax import deffn, defop from effectful.ops.types import Expr, Interpretation, Operation, Term @@ -68,7 +67,7 @@ def call[**P, T](fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: } with handler(subs): return evaluate(body) - elif not any(isinstance(a, Term) for a in tree.flatten((fn, args, kwargs))): + elif not any(isinstance(a, Term) for a in (fn, *args, *kwargs.values())): return fn(*args, **kwargs) else: raise NotImplementedError @@ -246,18 +245,24 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T] 6 """ - if intp is None: - from effectful.internals.runtime import get_interpretation + from effectful.internals.runtime import get_interpretation, interpreter - intp = get_interpretation() + if intp is not None: + return interpreter(intp)(evaluate)(expr) if isinstance(expr, Term): - (args, kwargs) = tree.map_structure( - functools.partial(evaluate, intp=intp), (expr.args, expr.kwargs) - ) - return apply.__default_rule__(intp, expr.op, *args, **kwargs) - elif tree.is_nested(expr): - return tree.map_structure(functools.partial(evaluate, intp=intp), expr) + args = tuple(evaluate(arg) for arg in expr.args) + kwargs = {k: evaluate(v) for k, v in expr.kwargs.items()} + return expr.op(*args, **kwargs) + elif isinstance(expr, collections.abc.Mapping): + return type(expr)((evaluate(k), evaluate(v)) for k, v in expr.items()) + elif isinstance(expr, collections.abc.Sequence) and not isinstance(expr, str): + return type(expr)(evaluate(e) for e in expr) + elif isinstance(expr, collections.abc.Set): + return type(expr)(evaluate(e) for e in expr) + elif isinstance(expr, Operation): + op_intp = get_interpretation().get(expr, expr) + return op_intp if isinstance(op_intp, Operation) else expr else: return expr diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 8a39eb9e..0ef29de3 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -2,14 +2,14 @@ import dataclasses import functools import inspect +import numbers +import operator import random import types import typing from collections.abc import Callable, Iterable, Mapping from typing import Annotated, Concatenate -import tree - from effectful.ops.types import Annotation, Expr, Operation, Term @@ -355,16 +355,17 @@ def analyze(self, bound_sig: inspect.BoundArguments) -> frozenset[Operation]: else: param_bound_vars = {param_value} elif param_ordinal: # Only process if there's a Scoped annotation - # We can't use tree.flatten here because we want to be able - # to see dict keys + def extract_operations(obj): if isinstance(obj, Operation): param_bound_vars.add(obj) - elif isinstance(obj, dict): + elif isinstance(obj, collections.abc.Mapping): for k, v in obj.items(): extract_operations(k) extract_operations(v) - elif isinstance(obj, list | set | tuple): + elif isinstance( + obj, collections.abc.Sequence | collections.abc.Set + ): for v in obj: extract_operations(v) @@ -535,6 +536,10 @@ def __init__( self._freshening = freshening or [] self.__signature__ = inspect.signature(default) + @property + def __isabstractmethod__(self) -> bool: + return False + def __eq__(self, other): if not isinstance(other, Operation): return NotImplemented @@ -662,7 +667,7 @@ def func() -> t: # type: ignore def _[**P, T](t: Callable[P, T], *, name: str | None = None) -> Operation[P, T]: @functools.wraps(t) def func(*args, **kwargs): - if not any(isinstance(a, Term) for a in tree.flatten((args, kwargs))): + if not any(isinstance(a, Term) for a in (*args, *kwargs.values())): return t(*args, **kwargs) else: raise NotImplementedError @@ -872,18 +877,6 @@ def defterm[T](__dispatch: Callable[[type], Callable[[T], Expr[T]]], value: T): return __dispatch(type(value))(value) -def _map_structure_and_keys(func, structure): - def _map_value(value): - if isinstance(value, dict): - return {func(k): v for k, v in value.items()} - elif not tree.is_nested(value): - return func(value) - else: - return value - - return tree.traverse(_map_value, structure, top_down=False) - - @_CustomSingleDispatchCallable def defdata[T]( __dispatch: Callable[[type], Callable[..., Expr[T]]], @@ -960,9 +953,6 @@ def _(op, *args, **kwargs): *{k: (v, kwarg_ctxs[k]) for k, v in kwargs.items()}.items(), ): if c: - v = _map_structure_and_keys( - lambda a: renaming.get(a, a) if isinstance(a, Operation) else a, v - ) res = evaluate( v, intp={ @@ -982,6 +972,8 @@ def _(op, *args, **kwargs): return typed_term +@defterm.register(str) +@defterm.register(bytes) @defterm.register(object) @defterm.register(Operation) @defterm.register(Term) @@ -991,48 +983,6 @@ def _[T](value: T) -> T: return value -@defdata.register(object) -class _BaseTerm[T](Term[T]): - _op: Operation[..., T] - _args: collections.abc.Sequence[Expr] - _kwargs: collections.abc.Mapping[str, Expr] - - def __init__( - self, - op: Operation[..., T], - *args: Expr, - **kwargs: Expr, - ): - self._op = op - self._args = args - self._kwargs = kwargs - - def __eq__(self, other) -> bool: - from effectful.ops.syntax import syntactic_eq - - return syntactic_eq(self, other) - - @property - def op(self): - return self._op - - @property - def args(self): - return self._args - - @property - def kwargs(self): - return self._kwargs - - -@defdata.register(collections.abc.Callable) -class _CallableTerm[**P, T](_BaseTerm[collections.abc.Callable[P, T]]): - def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]: - from effectful.ops.semantics import call - - return call(self, *args, **kwargs) # type: ignore - - def trace[**P, T](value: Callable[P, T]) -> Callable[P, T]: """Convert a callable to a term by calling it with appropriately typed free variables. @@ -1097,30 +1047,6 @@ def defstream[S, T, A, B]( raise NotImplementedError -@defdata.register(collections.abc.Iterable) -class _IterableTerm[T](_BaseTerm[collections.abc.Iterable[T]]): - @defop - def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: - if not isinstance(self, Term): - return iter(self) - else: - raise NotImplementedError - - -@defdata.register(collections.abc.Iterator) -class _IteratorTerm[T](_IterableTerm[T]): - @defop - def __next__(self: collections.abc.Iterator[T]) -> T: - if not isinstance(self, Term): - return next(self) - else: - raise NotImplementedError - - -iter_ = _IterableTerm.__iter__ -next_ = _IteratorTerm.__next__ - - def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool: """Syntactic equality, ignoring the interpretation of the terms. @@ -1133,21 +1059,28 @@ def syntactic_eq[T](x: Expr[T], other: Expr[T]) -> bool: if isinstance(x, Term) and isinstance(other, Term): op, args, kwargs = x.op, x.args, x.kwargs op2, args2, kwargs2 = other.op, other.args, other.kwargs - try: - tree.assert_same_structure( - (op, args, kwargs), (op2, args2, kwargs2), check_types=True - ) - except (TypeError, ValueError): - return False - return all( - tree.flatten( - tree.map_structure( - syntactic_eq, (op, args, kwargs), (op2, args2, kwargs2) - ) - ) + return ( + op == op2 + and len(args) == len(args2) + and set(kwargs) == set(kwargs2) + and all(syntactic_eq(a, b) for a, b in zip(args, args2)) + and all(syntactic_eq(kwargs[k], kwargs2[k]) for k in kwargs) ) elif isinstance(x, Term) or isinstance(other, Term): return False + elif isinstance(x, collections.abc.Mapping) and isinstance( + other, collections.abc.Mapping + ): + return all( + k in x and k in other and syntactic_eq(x[k], other[k]) + for k in set(x) | set(other) + ) + elif isinstance(x, collections.abc.Sequence) and isinstance( + other, collections.abc.Sequence + ): + return len(x) == len(other) and all( + syntactic_eq(a, b) for a, b in zip(x, other) + ) else: return x == other @@ -1263,3 +1196,606 @@ def implements[**P, V](op: Operation[P, V]): """ return _ImplementedOperation(op) + + +@defdata.register(object) +class _BaseTerm[T](Term[T]): + _op: Operation[..., T] + _args: collections.abc.Sequence[Expr] + _kwargs: collections.abc.Mapping[str, Expr] + + def __init__( + self, + op: Operation[..., T], + *args: Expr, + **kwargs: Expr, + ): + self._op = op + self._args = args + self._kwargs = kwargs + + def __eq__(self, other) -> bool: + from effectful.ops.syntax import syntactic_eq + + return syntactic_eq(self, other) + + @property + def op(self): + return self._op + + @property + def args(self): + return self._args + + @property + def kwargs(self): + return self._kwargs + + +@defdata.register(collections.abc.Callable) +class _CallableTerm[**P, T](_BaseTerm[collections.abc.Callable[P, T]]): + def __call__(self, *args: Expr, **kwargs: Expr) -> Expr[T]: + from effectful.ops.semantics import call + + return call(self, *args, **kwargs) # type: ignore + + +@defdata.register(collections.abc.Iterable) +class _IterableTerm[T]( + _BaseTerm[collections.abc.Iterable[T]], collections.abc.Iterable[T] +): + @defop + def __iter__(self: collections.abc.Iterable[T]) -> collections.abc.Iterator[T]: + if not isinstance(self, Term): + return self.__iter__() + else: + raise NotImplementedError + + +@defdata.register(collections.abc.Iterator) +class _IteratorTerm[T](_IterableTerm[T], collections.abc.Iterator[T]): + @defop + def __next__(self: collections.abc.Iterator[T]) -> T: + if not isinstance(self, Term): + return self.__next__() + else: + raise NotImplementedError + + +# @defdata.register(collections.abc.Reversible) +class _ReversibleTerm[T](_IterableTerm[T], collections.abc.Reversible[T]): + @defop + def __reversed__( + self: collections.abc.Reversible[T], + ) -> collections.abc.Iterator[T]: + if not isinstance(self, Term): + return self.__reversed__() + else: + raise NotImplementedError + + +@defdata.register(collections.abc.Collection) +class _CollectionTerm[T](_IterableTerm[T], collections.abc.Collection[T]): + @defop + def __len__(self: collections.abc.Sized) -> int: + if not isinstance(self, Term): + return self.__len__() + else: + raise NotImplementedError + + @defop + def __contains__(self: collections.abc.Container[T], item: T) -> bool: + if not isinstance(self, Term) and not isinstance(item, Term): + return self.__contains__(item) + else: + raise NotImplementedError + + +@defdata.register(collections.abc.Sequence) +class _SequenceTerm[T]( + _CollectionTerm[T], _ReversibleTerm[T], collections.abc.Sequence[T] +): + @defop + def __getitem__(self: collections.abc.Sequence[T], index: int | slice) -> T: + if not isinstance(self, Term) and not isinstance(index, Term): + return self.__getitem__(index) + else: + raise NotImplementedError + + @defop + def index( + self: collections.abc.Sequence[T], + value: T, + start: int = 0, + stop: int | None = None, + ) -> int: + if not isinstance(self, Term) and not isinstance(value, Term): + return self.index(value, start, stop) + else: + raise NotImplementedError + + @defop + def count(self: collections.abc.Sequence[T], value: T) -> int: + if not isinstance(self, Term) and not isinstance(value, Term): + return self.count(value) + else: + raise NotImplementedError + + +@defdata.register(collections.abc.Set) +@functools.total_ordering +class _SetTerm[T](_CollectionTerm[T], collections.abc.Set[T]): + @defop + def __eq__(self: collections.abc.Set[T], other: collections.abc.Set[T]) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__eq__(other) + else: + raise NotImplementedError + + @defop + def __lt__(self: collections.abc.Set[T], other: collections.abc.Set[T]) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__lt__(other) + else: + raise NotImplementedError + + @defop + def __and__[V]( + self: collections.abc.Set[T], other: collections.abc.Set[V] + ) -> collections.abc.Set[T | V]: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__and__(other) + else: + raise NotImplementedError + + @defop + def __or__[V]( + self: collections.abc.Set[T], other: collections.abc.Set[V] + ) -> collections.abc.Set[T | V]: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__or__(other) + else: + raise NotImplementedError + + @defop + def __xor__[V]( + self: collections.abc.Set[T], other: collections.abc.Set[V] + ) -> collections.abc.Set[T | V]: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__xor__(other) + else: + raise NotImplementedError + + def __rxor__[V]( + self: collections.abc.Set[T], other: collections.abc.Set[V] + ) -> collections.abc.Set[T | V]: + """Right-hand side XOR operator for sets.""" + return type(self).__xor__(other, self) + + @defop + def __sub__[V]( + self: collections.abc.Set[T], other: collections.abc.Set[V] + ) -> collections.abc.Set[T]: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__sub__(other) + else: + raise NotImplementedError + + def __rsub__[V]( + self: collections.abc.Set[T], other: collections.abc.Set[V] + ) -> collections.abc.Set[T]: + """Right-hand side subtraction operator for sets.""" + return type(self).__sub__(other, self) + + @defop + def isdisjoint[V]( + self: collections.abc.Set[T], other: collections.abc.Iterable[V] + ) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.isdisjoint(other) + else: + raise NotImplementedError + + +@defdata.register(collections.abc.Mapping) +class _MappingTerm[S, V](_CollectionTerm[S], collections.abc.Mapping[S, V]): + @defop + def __getitem__(self: collections.abc.Mapping[S, V], key: S) -> V: + if not isinstance(self, Term) and not isinstance(key, Term): + return self.__getitem__(key) + else: + raise NotImplementedError + + @defop + def __eq__( + self: collections.abc.Mapping[S, V], other: collections.abc.Mapping[S, V] + ) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__eq__(other) + else: + raise NotImplementedError + + @defop + def keys(self: collections.abc.Mapping[S, V]) -> collections.abc.Set[S]: + if not isinstance(self, Term): + return self.keys() + else: + raise NotImplementedError + + @defop + def values(self: collections.abc.Mapping[S, V]) -> collections.abc.Collection[V]: + if not isinstance(self, Term): + return self.values() + else: + raise NotImplementedError + + @defop + def items(self: collections.abc.Mapping[S, V]) -> collections.abc.Set[tuple[S, V]]: + if not isinstance(self, Term): + return self.items() + else: + raise NotImplementedError + + +@defdata.register(numbers.Number) +@functools.total_ordering +class _NumberTerm[T: numbers.Number](_BaseTerm[T], numbers.Number): + def __hash__(self): + return id(self) + + def __complex__(self) -> complex: + raise ValueError("Cannot convert term to complex number") + + def __float__(self) -> float: + raise ValueError("Cannot convert term to float") + + def __int__(self) -> int: + raise ValueError("Cannot convert term to int") + + def __bool__(self) -> bool: + raise ValueError("Cannot convert term to bool") + + @defop + @property + def real(self: numbers.Complex) -> float: + if not isinstance(self, Term): + return self.real + else: + raise NotImplementedError + + @defop + @property + def imag(self: numbers.Complex) -> float: + if not isinstance(self, Term): + return self.imag + else: + raise NotImplementedError + + @defop + def conjugate(self: T) -> T: + if not isinstance(self, Term): + return self.conjugate() + else: + raise NotImplementedError + + @defop + @property + def numerator(self: numbers.Rational) -> int: + if not isinstance(self, Term): + return self.numerator + else: + raise NotImplementedError + + @defop + @property + def denominator(self: numbers.Rational) -> int: + if not isinstance(self, Term): + return self.denominator + else: + raise NotImplementedError + + @defop + def __abs__(self: numbers.Complex) -> float: + """Return the absolute value of the term.""" + if not isinstance(self, Term): + return self.__abs__() + else: + raise NotImplementedError + + @defop + def __neg__(self: T) -> T: + if not isinstance(self, Term): + return self.__neg__() + else: + raise NotImplementedError + + @defop + def __pos__(self: T) -> T: + if not isinstance(self, Term): + return self.__pos__() + else: + raise NotImplementedError + + @defop + def __trunc__(self: numbers.Real) -> int: + if not isinstance(self, Term): + return self.__trunc__() + else: + raise NotImplementedError + + @defop + def __floor__(self: numbers.Real) -> int: + if not isinstance(self, Term): + return self.__floor__() + else: + raise NotImplementedError + + @defop + def __ceil__(self: numbers.Real) -> int: + if not isinstance(self, Term): + return self.__ceil__() + else: + raise NotImplementedError + + @defop + def __round__(self: numbers.Real, ndigits: int | None = None) -> int | float: + if not isinstance(self, Term) and not isinstance(ndigits, Term): + return self.__round__(ndigits) + else: + raise NotImplementedError + + @defop + def __invert__(self: numbers.Integral) -> int: + if not isinstance(self, Term): + return self.__invert__() + else: + raise NotImplementedError + + @defop + def __index__(self: numbers.Integral) -> int: + if not isinstance(self, Term): + return self.__index__() + else: + raise NotImplementedError + + @defop + def __eq__(self: numbers.Complex, other: numbers.Complex) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__eq__(other) + else: + return syntactic_eq(self, other) + + @defop + def __lt__(self: numbers.Real, other: numbers.Real) -> bool: + if not isinstance(self, Term) and not isinstance(other, Term): + return self.__lt__(other) + else: + raise NotImplementedError + + @defop + def __add__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__add__(self, other) + else: + raise NotImplementedError + + def __radd__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__add__(self) + elif not isinstance(other, Term): + return type(self).__add__(other, self) + else: + return NotImplemented + + @defop + def __sub__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__sub__(self, other) + else: + raise NotImplementedError + + def __rsub__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__sub__(self) + elif not isinstance(other, Term): + return type(self).__sub__(other, self) + else: + return NotImplemented + + @defop + def __mul__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__mul__(self, other) + else: + raise NotImplementedError + + def __rmul__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__mul__(self) + elif not isinstance(other, Term): + return type(self).__mul__(other, self) + else: + return NotImplemented + + @defop + def __truediv__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__truediv__(self, other) + else: + raise NotImplementedError + + def __rtruediv__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__truediv__(self) + elif not isinstance(other, Term): + return type(self).__truediv__(other, self) + else: + return NotImplemented + + @defop + def __floordiv__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__floordiv__(self, other) + else: + raise NotImplementedError + + def __rfloordiv__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__floordiv__(self) + elif not isinstance(other, Term): + return type(self).__floordiv__(other, self) + else: + return NotImplemented + + @defop + def __mod__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__mod__(self, other) + else: + raise NotImplementedError + + def __rmod__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__mod__(self) + elif not isinstance(other, Term): + return type(self).__mod__(other, self) + else: + return NotImplemented + + @defop + def __pow__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__pow__(self, other) + else: + raise NotImplementedError + + def __rpow__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__pow__(self) + elif not isinstance(other, Term): + return type(self).__pow__(other, self) + else: + return NotImplemented + + @defop + def __lshift__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__lshift__(self, other) + else: + raise NotImplementedError + + def __rlshift__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__lshift__(self) + elif not isinstance(other, Term): + return type(self).__lshift__(other, self) + else: + return NotImplemented + + @defop + def __rshift__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__rshift__(self, other) + else: + raise NotImplementedError + + def __rrshift__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__rshift__(self) + elif not isinstance(other, Term): + return type(self).__rshift__(other, self) + else: + return NotImplemented + + @defop + def __and__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__and__(self, other) + else: + raise NotImplementedError + + def __rand__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__and__(self) + elif not isinstance(other, Term): + return type(self).__and__(other, self) + else: + return NotImplemented + + @defop + def __xor__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__xor__(self, other) + else: + raise NotImplementedError + + def __rxor__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__xor__(self) + elif not isinstance(other, Term): + return type(self).__xor__(other, self) + else: + return NotImplemented + + @defop + def __or__(self: T, other: T) -> T: + if not isinstance(self, Term) and not isinstance(other, Term): + return operator.__or__(self, other) + else: + raise NotImplementedError + + def __ror__(self, other): + if isinstance(other, Term) and isinstance(other, type(self)): + return other.__or__(self) + elif not isinstance(other, Term): + return type(self).__or__(other, self) + else: + return NotImplemented + + +# # numbers.Complex polymorphic binary methods +# __add__, __radd__ = _wrapper_generator(operator.__add__) +# __sub__, __rsub__ = _wrapper_generator(operator.__sub__) +# __mul__, __rmul__ = _wrapper_generator(operator.__mul__) +# __truediv__, __rtruediv__ = _wrapper_generator(operator.__truediv__) +# __pow__, __rpow__ = _wrapper_generator(operator.__pow__) + +# # numbers.Real polymorphic binary methods +# __floordiv__, __rfloordiv__ = _wrapper_generator(operator.__floordiv__) +# __mod__, __rmod__ = _wrapper_generator(operator.__mod__) + +# # numbers.Integral polymorphic binary methods +# __lshift__, __rlshift__ = _wrapper_generator(operator.__lshift__) +# __rshift__, __rrshift__ = _wrapper_generator(operator.__rshift__) +# __and__, __rand__ = _wrapper_generator(operator.__and__) +# __xor__, __rxor__ = _wrapper_generator(operator.__xor__) +# __or__, __ror__ = _wrapper_generator(operator.__or__) + + +@defdata.register(numbers.Complex) +@numbers.Complex.register +class _ComplexTerm[T: numbers.Complex](_NumberTerm[T]): + pass + + +@defdata.register(numbers.Real) +@numbers.Real.register +class _RealTerm[T: numbers.Real](_ComplexTerm[T]): + pass + + +@defdata.register(numbers.Rational) +@numbers.Rational.register +class _RationalTerm[T: numbers.Rational](_RealTerm[T]): + pass + + +@defdata.register(numbers.Integral) +@numbers.Integral.register +class _IntegralTerm[T: numbers.Integral](_RationalTerm[T]): + pass + + +@defdata.register(bool) +class _BoolTerm[T: bool](_IntegralTerm[T]): + pass diff --git a/tests/test_handlers_numbers.py b/tests/test_handlers_numbers.py index 31d117dd..763a09ec 100644 --- a/tests/test_handlers_numbers.py +++ b/tests/test_handlers_numbers.py @@ -253,7 +253,6 @@ def test_evaluate_2(): y = defop(int, name="y") t = x() + y() assert isinstance(t, Term) - assert t.op.__name__ == "add" with handler({x: lambda: 1, y: lambda: 3}): assert evaluate(t) == 4 diff --git a/tests/test_ops_syntax.py b/tests/test_ops_syntax.py index d4e06de1..d35b5bcc 100644 --- a/tests/test_ops_syntax.py +++ b/tests/test_ops_syntax.py @@ -1,3 +1,4 @@ +import collections.abc import functools import inspect from collections.abc import Callable, Iterable, Iterator, Mapping @@ -10,13 +11,10 @@ from effectful.ops.syntax import ( Scoped, _CustomSingleDispatchCallable, - _map_structure_and_keys, deffn, defop, defstream, defterm, - iter_, - next_, ) from effectful.ops.types import Operation, Term @@ -111,13 +109,6 @@ def f(x): assert f_op != ff_op -def test_map_structure_and_keys(): - s = {1: 2, 3: [4, 5, (6, {7: 8})]} - expected = {2: 3, 4: [5, 6, (7, {8: 9})]} - actual = _map_structure_and_keys(lambda x: x + 1, s) - assert actual == expected - - def test_scoped_collections(): """Test that Scoped annotations work with tree-structured collections containing Operations.""" @@ -205,11 +196,11 @@ def test_term_str(): assert str(x1) == str(x2) == str(x3) == "x" assert repr(x1) != repr(x2) != repr(x3) - assert str(x1() + x2()) == "add(x(), x!1())" - assert str(x1() + x1()) == "add(x(), x())" - assert str(deffn(x1() + x1(), x1)) == "deffn(add(x(), x()), x)" - assert str(deffn(x1() + x1(), x2)) == "deffn(add(x(), x()), x!1)" - assert str(deffn(x1() + x2(), x1)) == "deffn(add(x(), x!1()), x)" + assert str(x1() + x2()) == "__add__(x(), x!1())" + assert str(x1() + x1()) == "__add__(x(), x())" + assert str(deffn(x1() + x1(), x1)) == "deffn(__add__(x(), x()), x)" + assert str(deffn(x1() + x1(), x2)) == "deffn(__add__(x(), x()), x!1)" + assert str(deffn(x1() + x2(), x1)) == "deffn(__add__(x(), x!1()), x)" def test_defop_singledispatch(): @@ -514,13 +505,13 @@ def cons_iterable(*args: int) -> Iterable[int]: assert isinstance(tm_iter, Term) assert isinstance(tm_iter, Iterator) assert issubclass(typeof(tm_iter), Iterator) - assert tm_iter.op is iter_ + assert tm_iter.op is type(tm).__iter__ tm_iter_next = next(tm_iter) assert isinstance(tm_iter_next, Term) # assert isinstance(tm_iter_next, numbers.Number) # TODO # assert issubclass(typeof(tm_iter_next), numbers.Number) - assert tm_iter_next.op is next_ + assert tm_iter_next.op is type(tm_iter).__next__ assert list(tm.args) == [1, 2, 3] @@ -542,10 +533,61 @@ def test_defstream_1(): assert isinstance(tm_iter, Term) assert isinstance(tm_iter, Iterator) assert issubclass(typeof(tm_iter), Iterator) - assert tm_iter.op is iter_ + assert tm_iter.op is type(tm).__iter__ tm_iter_next = next(tm_iter) assert isinstance(tm_iter_next, Term) # assert isinstance(tm_iter_next, numbers.Number) # TODO # assert issubclass(typeof(tm_iter_next), numbers.Number) - assert tm_iter_next.op is next_ + assert tm_iter_next.op is type(tm_iter).__next__ + + +def test_defterm_sequence(): + @defop + def my_sequence(xs: tuple[int, ...]) -> tuple[int, ...]: + raise NotImplementedError + + x = defop(int, name="x") + tm = my_sequence((4, 5, x() + 1)) + + assert isinstance(tm, Term) + assert isinstance(tm, collections.abc.Sequence) + assert issubclass(typeof(tm), collections.abc.Sequence) + assert tm.op is my_sequence + assert tm.args == ((4, 5, x() + 1),) + + tm_0 = tm[0] + assert isinstance(tm_0, Term) + assert isinstance(tm_0.op, Operation) + assert tm_0.op is type(tm).__getitem__ + assert tm_0.args == (tm, 0) + + # Test that the term can be evaluated + with handler({my_sequence: lambda xs: tuple(x * 2 for x in xs), x: lambda: 0}): + assert evaluate(tm) == (8, 10, 2) + + +def test_defterm_mapping(): + @defop + def my_mapping(k: int, v: int) -> dict[int, int]: + raise NotImplementedError + + x = defop(int, name="x") + tm = my_mapping(3, x() + 1) + + assert isinstance(tm, Term) + assert isinstance(tm, collections.abc.Mapping) + assert issubclass(typeof(tm), collections.abc.Mapping) + assert tm.op is my_mapping + assert tm.args == (3, x() + 1) + assert tm.kwargs == {} + + tm_3 = tm[3] + assert isinstance(tm_3, Term) + assert isinstance(tm_3.op, Operation) + assert tm_3.op is type(tm).__getitem__ + assert tm_3.args == (tm, 3) + + # Test that the term can be evaluated + with handler({my_mapping: lambda k, v: {k: v}, x: lambda: 0}): + assert evaluate(tm) == {3: 1}