Skip to content

Commit b8ea347

Browse files
committed
Implement.
1 parent a932dd0 commit b8ea347

2 files changed

Lines changed: 39 additions & 1 deletion

File tree

typemap/type_eval/_apply_generic.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,14 @@ def _resolved_function_signature(func, args):
291291
return sig
292292

293293

294-
def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
294+
def get_local_defns(
295+
boxed: Boxed,
296+
) -> tuple[
297+
dict[str, Any],
298+
dict[
299+
str, types.FunctionType | classmethod | staticmethod | WrappedOverloaded
300+
],
301+
]:
295302
from typemap.typing import GenericCallable
296303

297304
annos: dict[str, Any] = {}
@@ -327,6 +334,8 @@ def get_local_defns(boxed: Boxed) -> tuple[dict[str, Any], dict[str, Any]]:
327334
# XXX: This is totally wrong; we still need to do
328335
# substitute in class vars
329336
local_fn = stuff
337+
elif overloaded := _is_overloaded_function(stuff):
338+
local_fn = overloaded
330339

331340
# If we got stuck, we build a GenericCallable that
332341
# computes the type once it has been given type
@@ -370,6 +379,23 @@ def lam(*vs):
370379
return annos, dct
371380

372381

382+
@dataclasses.dataclass(frozen=True)
383+
class WrappedOverloaded:
384+
functions: tuple[types.FunctionType, ...]
385+
386+
387+
def _is_overloaded_function(func):
388+
module_overload_registry = typing._overload_registry[func.__module__]
389+
if not module_overload_registry:
390+
return None
391+
392+
func_overload_registry = module_overload_registry[func.__qualname__]
393+
if not func_overload_registry:
394+
return
395+
396+
return WrappedOverloaded(tuple(func_overload_registry.values()))
397+
398+
373399
def flatten_class_new_proto(cls: type) -> type:
374400
# This is a hacky version of flatten_class that works by using
375401
# NewProtocol on Members!

typemap/type_eval/_eval_operators.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
Member,
4040
Members,
4141
NewProtocol,
42+
Overloaded,
4243
Param,
4344
RaiseError,
4445
Slice,
@@ -169,6 +170,17 @@ def get_annotated_method_hints(cls, *, ctx):
169170
object,
170171
acls,
171172
)
173+
elif isinstance(attr, _apply_generic.WrappedOverloaded):
174+
overloads = [
175+
_function_type(_eval_types(of, ctx), receiver_type=acls)
176+
for of in attr.functions
177+
]
178+
hints[name] = (
179+
Overloaded[*overloads],
180+
("ClassVar",),
181+
object,
182+
acls,
183+
)
172184

173185
return hints
174186

0 commit comments

Comments
 (0)