Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
81 commits
Select commit Hold shift + click to select a range
003cd54
stash
eb8680 Jul 1, 2025
4ecc29e
avoid defterm
eb8680 Jul 1, 2025
76abd88
format
eb8680 Jul 1, 2025
abd67c6
nit
eb8680 Jul 1, 2025
0aae217
stash
eb8680 Jul 2, 2025
24bf5cb
fix bugs
eb8680 Jul 3, 2025
284ac3e
nit
eb8680 Jul 3, 2025
49f678f
first commit
eb8680 Jul 3, 2025
d99753b
format
eb8680 Jul 3, 2025
b4417f3
test passes
eb8680 Jul 3, 2025
02eb5ec
types
eb8680 Jul 3, 2025
8a5152d
add docstring with doctests for unify
eb8680 Jul 3, 2025
13877f2
more test cases
eb8680 Jul 3, 2025
0dbd307
case
eb8680 Jul 3, 2025
e675bf6
docstring and test cases for infer_return_type
eb8680 Jul 3, 2025
90dd9fb
canonicalize
eb8680 Jul 4, 2025
c95ea01
fixes
eb8680 Jul 4, 2025
4a1461f
fix error type
eb8680 Jul 4, 2025
a27d811
add _nested_type for sequences
eb8680 Jul 4, 2025
c5c09a3
make _nested_type singledispatch extensible
eb8680 Jul 4, 2025
ce12433
term case
eb8680 Jul 4, 2025
f2492af
nit
eb8680 Jul 4, 2025
0cb5516
docstrings and tests for canonicalize and nested_type
eb8680 Jul 4, 2025
87dd5c8
no notimmplentederror
eb8680 Jul 4, 2025
841fc49
variadic params
eb8680 Jul 4, 2025
c1d102a
remove dead code paths in unify
eb8680 Jul 4, 2025
25eaa57
simplify unify
eb8680 Jul 4, 2025
dfd3507
union
eb8680 Jul 4, 2025
6d67944
tweak
eb8680 Jul 4, 2025
652061c
cleanup
eb8680 Jul 4, 2025
c651639
fix none case
eb8680 Jul 7, 2025
a6f0188
more aggressive canonicalize applied only during unification
eb8680 Jul 7, 2025
e79b5d3
update canonicalize test
eb8680 Jul 8, 2025
ef34247
remove dead path
eb8680 Jul 8, 2025
748bcac
paramspec failure
eb8680 Jul 8, 2025
08d149f
literal and optional
eb8680 Jul 8, 2025
0efe01d
paramspec canonicalize
eb8680 Jul 8, 2025
b52c7df
simplify alias
eb8680 Jul 8, 2025
c5f0f1b
single unify call
eb8680 Jul 8, 2025
bd006b8
fast paths
eb8680 Jul 8, 2025
84e8734
union handling
eb8680 Jul 8, 2025
a6426ae
freshen
eb8680 Jul 8, 2025
3a8f726
type of freshen
eb8680 Jul 8, 2025
4639ce0
reorder
eb8680 Jul 8, 2025
ab7b1ff
fixpoint in substitute
eb8680 Jul 8, 2025
de9e7b6
tweak types
eb8680 Jul 8, 2025
9dc9acd
empty params
eb8680 Jul 8, 2025
7fd9927
tweak substitute type
eb8680 Jul 8, 2025
7330f23
doctest
eb8680 Jul 8, 2025
acf9f11
separate nested_type
eb8680 Jul 8, 2025
9ec83f8
handle defaults
eb8680 Jul 8, 2025
d7fc53b
update
eb8680 Jul 10, 2025
2a99670
remove test ids
eb8680 Jul 10, 2025
3e01b8c
add to sphinx
eb8680 Jul 10, 2025
13bbf74
fix doctests
eb8680 Jul 10, 2025
f655dce
fix type checking
eb8680 Jul 10, 2025
bc371ee
str
eb8680 Jul 10, 2025
2b5db42
lint and format
eb8680 Jul 10, 2025
2dddc92
freshen -> _freshen
eb8680 Jul 10, 2025
2199d63
factor out freetypevars and susbtitute
eb8680 Jul 10, 2025
c6fe3fc
truncate names
eb8680 Jul 10, 2025
97f4862
doctest
eb8680 Jul 10, 2025
23de369
doctest
eb8680 Jul 10, 2025
f8247da
add compositional tests
eb8680 Jul 10, 2025
e173387
add tests for function types
eb8680 Jul 10, 2025
6b532af
ellipsis and paramspec
eb8680 Jul 10, 2025
d9578bb
variadic tuple logic and union tests
eb8680 Jul 10, 2025
05a7849
fix union unify pattern
eb8680 Jul 10, 2025
e0cf9b3
Use unification to implement `Operation.__type_rule__` (#300)
eb8680 Jul 11, 2025
e8c1f91
Merge branch 'master' into eb-type-inference
eb8680 Jul 11, 2025
92f557a
Merge branch 'master' into eb-type-inference
eb8680 Jul 11, 2025
2345f12
update to 3.12
eb8680 Jul 11, 2025
86c5431
union
eb8680 Jul 11, 2025
3811b83
try truncating??
eb8680 Jul 11, 2025
4263617
tweaks
eb8680 Jul 11, 2025
3094ad4
Merge branch 'master' into eb-notree
eb8680 Jul 11, 2025
c8bb57c
format, add set term
eb8680 Jul 11, 2025
7add441
ops
eb8680 Jul 11, 2025
fcd3248
Merge branch 'eb-type-inference' into eb-notree-unify
eb8680 Jul 11, 2025
507ee1f
stash work on collections
eb8680 Jul 13, 2025
187a80d
Merge branch 'master' into eb-collections
eb8680 Aug 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 24 additions & 240 deletions effectful/handlers/numbers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,244 +3,28 @@
"""

import numbers
import operator
from typing import Any

from effectful.ops.syntax import defdata, defop
from effectful.ops.types import Expr, Operation, Term


@defdata.register(numbers.Number)
@numbers.Number.register
class _NumberTerm(Term[numbers.Number]):
def __init__(
self, op: Operation[..., numbers.Number], *args: Expr, **kwargs: Expr
) -> None:
self._op = op
self._args = args
self._kwargs = kwargs

@property
def op(self) -> Operation[..., numbers.Number]:
return self._op

@property
def args(self) -> tuple:
return self._args

@property
def kwargs(self) -> dict:
return self._kwargs

def __hash__(self):
return hash((self.op, tuple(self.args), tuple(self.kwargs.items())))


# Complex specific methods
def _wrap_cmp(op):
def _wrapped_op[T_Number: numbers.Number](x: T_Number, y: T_Number) -> bool:
if not any(isinstance(a, Term) for a in (x, y)):
return op(x, y)
else:
raise NotImplementedError

_wrapped_op.__name__ = op.__name__
return _wrapped_op


def _wrap_binop(op):
def _wrapped_op[T_Number: numbers.Number](x: T_Number, y: T_Number) -> T_Number:
if not any(isinstance(a, Term) for a in (x, y)):
return op(x, y)
else:
raise NotImplementedError

_wrapped_op.__name__ = op.__name__
return _wrapped_op


def _wrap_unop(op):
def _wrapped_op[T_Number: numbers.Number](x: T_Number) -> T_Number:
if not isinstance(x, Term):
return op(x)
else:
raise NotImplementedError

_wrapped_op.__name__ = op.__name__
return _wrapped_op


add = defop(_wrap_binop(operator.add))
neg = defop(_wrap_unop(operator.neg))
pos = defop(_wrap_unop(operator.pos))
sub = defop(_wrap_binop(operator.sub))
mul = defop(_wrap_binop(operator.mul))
truediv = defop(_wrap_binop(operator.truediv))
pow = defop(_wrap_binop(operator.pow))
abs = defop(_wrap_unop(operator.abs))
eq = defop(_wrap_cmp(operator.eq))


@defdata.register(numbers.Complex)
@numbers.Complex.register
class _ComplexTerm(_NumberTerm, Term[numbers.Complex]):
def __bool__(self) -> bool:
raise ValueError("Cannot convert term to bool")

def __add__(self, other: Any) -> numbers.Real:
return add(self, other)

def __radd__(self, other: Any) -> numbers.Real:
return add(other, self)

def __neg__(self):
return neg(self)

def __pos__(self):
return pos(self)

def __sub__(self, other: Any) -> numbers.Real:
return sub(self, other)

def __rsub__(self, other: Any) -> numbers.Real:
return sub(other, self)

def __mul__(self, other: Any) -> numbers.Real:
return mul(self, other)

def __rmul__(self, other: Any) -> numbers.Real:
return mul(other, self)

def __truediv__(self, other: Any) -> numbers.Real:
return truediv(self, other)

def __rtruediv__(self, other: Any) -> numbers.Real:
return truediv(other, self)

def __pow__(self, other: Any) -> numbers.Real:
return pow(self, other)

def __rpow__(self, other: Any) -> numbers.Real:
return pow(other, self)

def __abs__(self) -> numbers.Real:
return abs(self)

def __eq__(self, other: Any) -> bool:
return eq(self, other)


# Real specific methods
floordiv = defop(_wrap_binop(operator.floordiv))
mod = defop(_wrap_binop(operator.mod))
lt = defop(_wrap_cmp(operator.lt))
le = defop(_wrap_cmp(operator.le))
gt = defop(_wrap_cmp(operator.gt))
ge = defop(_wrap_cmp(operator.ge))


@defdata.register(numbers.Real)
@numbers.Real.register
class _RealTerm(_ComplexTerm, Term[numbers.Real]):
# Real specific methods
def __float__(self) -> float:
raise ValueError("Cannot convert term to float")

def __trunc__(self) -> numbers.Integral:
raise NotImplementedError

def __floor__(self) -> numbers.Integral:
raise NotImplementedError

def __ceil__(self) -> numbers.Integral:
raise NotImplementedError

def __round__(self, ndigits=None) -> numbers.Integral:
raise NotImplementedError

def __floordiv__(self, other):
return floordiv(self, other)

def __rfloordiv__(self, other):
return floordiv(other, self)

def __mod__(self, other):
return mod(self, other)

def __rmod__(self, other):
return mod(other, self)

def __lt__(self, other):
return lt(self, other)

def __le__(self, other):
return le(self, other)


@defdata.register(numbers.Rational)
@numbers.Rational.register
class _RationalTerm(_RealTerm, Term[numbers.Rational]):
@property
def numerator(self):
raise NotImplementedError

@property
def denominator(self):
raise NotImplementedError


# Integral specific methods
index = defop(_wrap_unop(operator.index))
lshift = defop(_wrap_binop(operator.lshift))
rshift = defop(_wrap_binop(operator.rshift))
and_ = defop(_wrap_binop(operator.and_))
xor = defop(_wrap_binop(operator.xor))
or_ = defop(_wrap_binop(operator.or_))
invert = defop(_wrap_unop(operator.invert))


@defdata.register(numbers.Integral)
@numbers.Integral.register
class _IntegralTerm(_RationalTerm, Term[numbers.Integral]):
# Integral specific methods
def __int__(self) -> int:
raise ValueError("Cannot convert term to int")

def __index__(self) -> numbers.Integral:
return index(self)

def __pow__(self, exponent: Any, modulus=None) -> numbers.Integral:
return pow(self, exponent)

def __lshift__(self, other):
return lshift(self, other)

def __rlshift__(self, other):
return lshift(other, self)

def __rshift__(self, other):
return rshift(self, other)

def __rrshift__(self, other):
return rshift(other, self)

def __and__(self, other):
return and_(self, other)

def __rand__(self, other):
return and_(other, self)

def __xor__(self, other):
return xor(self, other)

def __rxor__(self, other):
return xor(other, self)

def __or__(self, other):
return or_(self, other)

def __ror__(self, other):
return or_(other, self)

def __invert__(self):
return invert(self)
from effectful.ops.syntax import defdata

add = defdata.dispatch(numbers.Number).__add__
neg = defdata.dispatch(numbers.Number).__neg__
pos = defdata.dispatch(numbers.Number).__pos__
sub = defdata.dispatch(numbers.Number).__sub__
mul = defdata.dispatch(numbers.Number).__mul__
truediv = defdata.dispatch(numbers.Number).__truediv__
pow = defdata.dispatch(numbers.Number).__pow__
abs = defdata.dispatch(numbers.Number).__abs__
floordiv = defdata.dispatch(numbers.Number).__floordiv__
mod = defdata.dispatch(numbers.Number).__mod__
eq = defdata.dispatch(numbers.Number).__eq__
lt = defdata.dispatch(numbers.Number).__lt__
le = defdata.dispatch(numbers.Number).__le__
gt = defdata.dispatch(numbers.Number).__gt__
ge = defdata.dispatch(numbers.Number).__ge__
index = defdata.dispatch(numbers.Number).__index__
lshift = defdata.dispatch(numbers.Number).__lshift__
rshift = defdata.dispatch(numbers.Number).__rshift__
and_ = defdata.dispatch(numbers.Number).__and__
xor = defdata.dispatch(numbers.Number).__xor__
or_ = defdata.dispatch(numbers.Number).__or__
invert = defdata.dispatch(numbers.Number).__invert__
29 changes: 17 additions & 12 deletions effectful/ops/semantics.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import collections.abc
import contextlib
import functools
import types
import typing
from collections.abc import Callable
from typing import Any

import tree

from effectful.ops.syntax import deffn, defop
from effectful.ops.types import Expr, Interpretation, Operation, Term

Expand Down Expand Up @@ -68,7 +67,7 @@ def call[**P, T](fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
}
with handler(subs):
return evaluate(body)
elif not any(isinstance(a, Term) for a in tree.flatten((fn, args, kwargs))):
elif not any(isinstance(a, Term) for a in (fn, *args, *kwargs.values())):
return fn(*args, **kwargs)
else:
raise NotImplementedError
Expand Down Expand Up @@ -246,18 +245,24 @@ def evaluate[T](expr: Expr[T], *, intp: Interpretation | None = None) -> Expr[T]
6

"""
if intp is None:
from effectful.internals.runtime import get_interpretation
from effectful.internals.runtime import get_interpretation, interpreter

intp = get_interpretation()
if intp is not None:
return interpreter(intp)(evaluate)(expr)

if isinstance(expr, Term):
(args, kwargs) = tree.map_structure(
functools.partial(evaluate, intp=intp), (expr.args, expr.kwargs)
)
return apply.__default_rule__(intp, expr.op, *args, **kwargs)
elif tree.is_nested(expr):
return tree.map_structure(functools.partial(evaluate, intp=intp), expr)
args = tuple(evaluate(arg) for arg in expr.args)
kwargs = {k: evaluate(v) for k, v in expr.kwargs.items()}
return expr.op(*args, **kwargs)
elif isinstance(expr, collections.abc.Mapping):
return type(expr)((evaluate(k), evaluate(v)) for k, v in expr.items())
elif isinstance(expr, collections.abc.Sequence) and not isinstance(expr, str):
return type(expr)(evaluate(e) for e in expr)
elif isinstance(expr, collections.abc.Set):
return type(expr)(evaluate(e) for e in expr)
elif isinstance(expr, Operation):
op_intp = get_interpretation().get(expr, expr)
return op_intp if isinstance(op_intp, Operation) else expr
else:
return expr

Expand Down
Loading
Loading