Skip to content

Commit 8afdb8b

Browse files
committed
Make __init_subclass__ work with nontrivial RHSes
Basically the idea is to evaluate it with the type variable substituted in. I had to rework the caching for boxes to prevent infinite recursion. It's still a TODO to properly report out GenericCallables for methods with nontrivial computation at the top level, but the basic idea is to *try* to get the annotation and raise a StuckException if we would need to use an operator on a variable while evaluating a Bool or Iter. Then if the exception got raised, we will produce a GenericCallable.
1 parent 03f4ad5 commit 8afdb8b

5 files changed

Lines changed: 113 additions & 72 deletions

File tree

tests/test_type_eval.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,14 +1690,11 @@ def g(self) -> int: ... # omitted
16901690

16911691
# Attrs
16921692
attrs = eval_typing(Attrs[B])
1693-
assert (
1694-
attrs
1695-
== tuple[
1696-
Member[Literal["a1"], int, Never, Never, A],
1697-
Member[Literal["a2"], str, Never, Never, B],
1698-
Member[Literal["b1"], str, Never, Never, B],
1699-
Member[Literal["b2"], str, Never, Never, B],
1700-
]
1693+
assert attrs.__args__ == (
1694+
Member[Literal["a1"], int, Never, Never, A],
1695+
Member[Literal["a2"], str, Never, Never, B],
1696+
Member[Literal["b1"], str, Never, Never, B],
1697+
Member[Literal["b2"], str, Never, Never, B],
17011698
)
17021699

17031700
# Members
@@ -1989,14 +1986,6 @@ def __init_subclass__[T](
19891986
)
19901987

19911988

1992-
type AttrsAsList[T] = UpdateClass[
1993-
*[Member[GetName[m], list[GetType[m]]] for m in Iter[Attrs[T]]]
1994-
]
1995-
type AttrsAsTuple[T] = UpdateClass[
1996-
*[Member[GetName[m], tuple[GetType[m]]] for m in Iter[Attrs[T]]]
1997-
]
1998-
1999-
20001989
def test_update_class_inheritance_02():
20011990
# __init_subclass__ calls follow normal MRO
20021991
class A:
@@ -2008,34 +1997,35 @@ def __init_subclass__[T](
20081997
super().__init_subclass__()
20091998

20101999
class B(A):
2011-
b: int
2000+
b: bytes
20122001

20132002
def __init_subclass__[T](
20142003
cls: type[T],
2015-
) -> AttrsAsList[T]:
2004+
) -> UpdateClass[
2005+
*[Member[GetName[m], list[GetType[m]]] for m in Iter[Attrs[T]]]
2006+
]:
20162007
super().__init_subclass__()
20172008

20182009
class C:
2019-
c: int
2010+
c: float
20202011

20212012
def __init_subclass__[T](
20222013
cls: type[T],
2023-
) -> AttrsAsTuple[T]:
2014+
) -> UpdateClass[
2015+
*[Member[GetName[m], tuple[GetType[m]]] for m in Iter[Attrs[T]]]
2016+
]:
20242017
super().__init_subclass__()
20252018

20262019
class D(B, C):
2027-
d: int
2020+
d: bool
20282021

20292022
attrs = eval_typing(Attrs[D])
20302023
# MRO = D, B, A, C, object
2031-
assert (
2032-
attrs
2033-
== tuple[
2034-
Member[Literal["c"], tuple[set[list[int]]], Never, Never, D],
2035-
Member[Literal["a"], tuple[set[list[int]]], Never, Never, D],
2036-
Member[Literal["b"], tuple[set[list[int]]], Never, Never, D],
2037-
Member[Literal["d"], tuple[set[list[int]]], Never, Never, D],
2038-
]
2024+
assert attrs.__args__ == (
2025+
Member[Literal["c"], tuple[set[list[float]]], Never, Never, D],
2026+
Member[Literal["a"], tuple[set[list[int]]], Never, Never, D],
2027+
Member[Literal["b"], tuple[set[list[bytes]]], Never, Never, D],
2028+
Member[Literal["d"], tuple[set[list[bool]]], Never, Never, D],
20392029
)
20402030

20412031

typemap/type_eval/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
eval_typing,
33
_get_current_context,
44
register_evaluator,
5+
StuckException,
56
_EvalProxy,
67
)
78
from ._apply_generic import flatten_class
@@ -24,6 +25,7 @@
2425
"flatten_class",
2526
"issubtype",
2627
"TypeMapError",
28+
"StuckException",
2729
"_EvalProxy",
2830
"_get_current_context",
2931
)

typemap/type_eval/_apply_generic.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@
1212
from . import _typing_inspect
1313

1414
if typing.TYPE_CHECKING:
15-
from typing import Any
15+
from typing import Any, Mapping
1616

1717

1818
@dataclasses.dataclass(frozen=True)
1919
class Boxed:
2020
cls: type[Any]
2121
bases: list[Boxed]
2222
args: dict[Any, Any]
23+
orig_cls: type[Any] | None = (
24+
None # Original class, before __init_subclass__ applied
25+
)
2326

2427
str_args: dict[str, Any] = dataclasses.field(init=False)
2528
mro: tuple[Boxed, ...] = dataclasses.field(init=False)
@@ -38,14 +41,22 @@ def __post_init__(self):
3841
object.__setattr__(
3942
self,
4043
"mro",
41-
_compute_mro(self),
44+
tuple(_compute_mro(self)),
4245
)
4346

47+
@property
48+
def canonical_cls(self):
49+
"""The class for the original boxing.
50+
51+
(Possibly a new one was created after __init_subclass__ applied.
52+
"""
53+
return self.orig_cls or self.cls
54+
4455
def alias_type(self):
4556
if self.args:
46-
return self.cls[*self.args.values()]
57+
return self.canonical_cls[*self.args.values()]
4758
else:
48-
return self.cls
59+
return self.canonical_cls
4960

5061
def __repr__(self):
5162
return f"Boxed<{self.cls} {self.args}>"
@@ -194,7 +205,7 @@ def make_func(
194205

195206
def get_annotations(
196207
obj: object,
197-
args: dict[str, object],
208+
args: Mapping[str, object],
198209
key: str = '__annotate__',
199210
annos_ok: bool = True,
200211
) -> Any | None:
@@ -224,7 +235,7 @@ def get_annotations(
224235
# Copy in any __type_params__ that aren't provided for, so that if
225236
# we have to eval, we have them.
226237
if params := getattr(obj, "__type_params__", None):
227-
args = args.copy()
238+
args = dict(args)
228239
for param in params:
229240
if str(param) not in args:
230241
args[str(param)] = param
@@ -263,9 +274,16 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
263274
# TODO: This annos_ok thing is a hack because processing
264275
# __annotations__ on methods broke stuff and I didn't want
265276
# to chase it down yet.
266-
if (
267-
rr := get_annotations(stuff, boxed.str_args, annos_ok=False)
268-
) is not None:
277+
try:
278+
rr = get_annotations(stuff, boxed.str_args, annos_ok=False)
279+
except _eval_typing.StuckException:
280+
# TODO: Either generate a GenericCallable or a
281+
# function with our own __annotate__ for this case
282+
# where we can't even fetch the signature without
283+
# trouble.
284+
rr = None
285+
286+
if rr is not None:
269287
local_fn = make_func(orig, rr)
270288
elif getattr(stuff, "__annotations__", None):
271289
# XXX: This is totally wrong; we still need to do

typemap/type_eval/_eval_operators.py

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import collections
22
import collections.abc
33
import contextlib
4+
import dataclasses
45
import functools
56
import inspect
67
import itertools
@@ -81,24 +82,34 @@ def _make_init_type(v):
8182
return typing.Literal[(v,)]
8283

8384

85+
def cached_box(cls, *, ctx):
86+
if str(cls).startswith('typemap.typing'):
87+
return _apply_generic.box(cls)
88+
if cls in ctx.box_cache:
89+
return ctx.box_cache[cls]
90+
ctx.box_cache[cls] = box = _apply_generic.box(cls)
91+
assert box.mro
92+
# if not all(b.mro for b in box.mro):
93+
# breakpoint()
94+
# assert all(b.mro for b in box.mro)
95+
96+
if new_box := _eval_init_subclass(box, ctx):
97+
ctx.box_cache[cls] = box = new_box
98+
return box
99+
100+
84101
def get_annotated_type_hints(cls, *, ctx, **kwargs):
85102
"""Get the type hints/quals for a cls annotated with definition site.
86103
87104
This traverses the mro and finds the definition site for each annotation.
88105
"""
89106

90-
# TODO: Cache the box (slash don't need it??)
91-
box = _apply_generic.box(cls)
107+
box = cached_box(cls, ctx=ctx)
92108

93109
hints = {}
94110
for abox in reversed(box.mro):
95111
acls = abox.alias_type()
96112

97-
if abox is box and (updated_cls := _eval_init_subclass(box, ctx)):
98-
# For the class itself, apply all UpdateClass from
99-
# ancesstors' __init_subclass__ to get the final type.
100-
abox = _apply_generic.box(updated_cls)
101-
102113
annos, _ = _apply_generic.get_local_defns(abox)
103114
for k, ty in annos.items():
104115
quals = set()
@@ -129,18 +140,12 @@ def get_annotated_type_hints(cls, *, ctx, **kwargs):
129140

130141

131142
def get_annotated_method_hints(cls, *, ctx):
132-
# TODO: Cache the box (slash don't need it??)
133-
box = _apply_generic.box(cls)
143+
box = cached_box(cls, ctx=ctx)
134144

135145
hints = {}
136146
for abox in reversed(box.mro):
137147
acls = abox.alias_type()
138148

139-
if abox is box and (updated_cls := _eval_init_subclass(box, ctx)):
140-
# For the class itself, apply all UpdateClass from
141-
# ancesstors' __init_subclass__ to get the final type.
142-
abox = _apply_generic.box(updated_cls)
143-
144149
_, dct = _apply_generic.get_local_defns(abox)
145150
for name, attr in dct.items():
146151
if isinstance(
@@ -167,25 +172,38 @@ def get_annotated_method_hints(cls, *, ctx):
167172

168173
def _eval_init_subclass(
169174
box: _apply_generic.Boxed, ctx: typing.Any
170-
) -> type | None:
175+
) -> _apply_generic.Boxed:
171176
"""Get type after all __init_subclass__ with UpdateClass are evaluated."""
172-
for abox in reversed(box.mro[1:]): # Skip the type itself
173-
if ms := _get_update_class_members(box.cls, abox.alias_type(), ctx=ctx):
174-
return _create_updated_class(box.cls, ms, ctx=ctx)
175-
176-
return None
177+
for abox in box.mro[1:]: # Skip the type itself
178+
with _child_context() as ctx:
179+
if ms := _get_update_class_members(
180+
box.cls, abox.alias_type(), ctx=ctx
181+
):
182+
nbox = _apply_generic.box(
183+
_create_updated_class(box.cls, ms, ctx=ctx)
184+
)
185+
# We want to preserve the original cls for Members output
186+
box = dataclasses.replace(nbox, orig_cls=box.canonical_cls)
187+
ctx.box_cache[box.cls] = box
188+
return box
177189

178190

179191
def _get_update_class_members(
180192
cls: type, base: type, ctx: typing.Any
181193
) -> list[Member] | None:
182-
if (
183-
(init_subclass := base.__dict__.get("__init_subclass__"))
184-
# XXX: We're using get_type_hints now to evaluate hints but
185-
# we should have our own generic infrastructure instead.
186-
# (I'm working on it -sully)
187-
and (init_subclass_annos := typing.get_type_hints(init_subclass))
188-
and (ret_annotation := init_subclass_annos.get("return"))
194+
init_subclass = base.__dict__.get("__init_subclass__")
195+
if not init_subclass:
196+
return None
197+
init_subclass = inspect.unwrap(init_subclass)
198+
199+
args = {}
200+
if type_params := getattr(init_subclass, '__type_params__', None):
201+
args[str(type_params[0])] = cls
202+
203+
init_subclass_annos = _apply_generic.get_annotations(init_subclass, args)
204+
205+
if init_subclass_annos and (
206+
ret_annotation := init_subclass_annos.get("return")
189207
):
190208
# Substitute the cls type var with the current class
191209
# This may not happen if cls is not generic!
@@ -211,12 +229,7 @@ def _get_update_class_members(
211229
)
212230

213231
# Evaluate the return annotation
214-
# Do it in a child context, so the evaluations are isolated. For
215-
# example, if the return annotation uses Attrs[MyClass], we want
216-
# Attrs[MyClass] to be evaluated with the updated class, not the
217-
# original.
218-
with _child_context() as ctx:
219-
evaled_ret = _eval_types(ret_annotation, ctx=ctx)
232+
evaled_ret = _eval_types(ret_annotation, ctx=ctx)
220233

221234
# If the result is an UpdateClass, return the members
222235
if (

typemap/type_eval/_eval_typing.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@
2525
__all__ = ("eval_typing",)
2626

2727

28+
class StuckException(Exception):
29+
"""Raised when a type operator receives a type variable argument."""
30+
31+
pass
32+
33+
2834
_eval_funcs: dict[type, typing.Callable[..., Any]] = {}
2935

3036

@@ -101,6 +107,10 @@ class EvalContext:
101107
typing.TypeAliasType | types.GenericAlias, typing.Any
102108
] = dataclasses.field(default_factory=dict)
103109

110+
box_cache: dict[typing.Any, _apply_generic.Boxed] = dataclasses.field(
111+
default_factory=dict
112+
)
113+
104114
# The typing.Any is really a types.FunctionType, but mypy gets
105115
# confused and wants to treat it as a MethodType.
106116
current_generic_alias: types.GenericAlias | typing.Any | None = None
@@ -123,7 +133,7 @@ def _ensure_context() -> typing.Iterator[EvalContext]:
123133
_current_context.set(ctx)
124134
ctx_set = True
125135
evaluator_token = nt.special_form_evaluator.set(
126-
lambda t: _eval_types(t, ctx)
136+
lambda t: _eval_types(t, _current_context.get()) # type: ignore[arg-type]
127137
)
128138

129139
try:
@@ -168,6 +178,7 @@ def _child_context() -> typing.Iterator[EvalContext]:
168178
recursive_type_alias=ctx.recursive_type_alias,
169179
known_recursive_types=ctx.known_recursive_types.copy(),
170180
current_generic_alias=ctx.current_generic_alias,
181+
box_cache=ctx.box_cache, # Not copied!
171182
)
172183
_current_context.set(child_ctx)
173184
yield child_ctx
@@ -394,6 +405,13 @@ def _eval_applied_class(obj: typing_GenericAlias, ctx: EvalContext):
394405
new_args = _eval_args(typing.get_args(obj), ctx)
395406

396407
if func := _eval_funcs.get(obj.__origin__):
408+
_tvars = (
409+
typing.TypeVar,
410+
typing.ParamSpec,
411+
typing.TypeVarTuple,
412+
)
413+
if any(isinstance(a, _tvars) for a in new_args):
414+
raise StuckException(obj)
397415
ret = func(*new_args, ctx=ctx)
398416
# return _eval_types(ret, ctx) # ???
399417
return ret

0 commit comments

Comments
 (0)