Skip to content

Commit 454989f

Browse files
authored
[mypyc] Using UnboundedType to access class object of a type annotation. (python#18874)
Fixes mypyc/mypyc#1087. This fix handles cases where type annotation is nested inside an imported module (like an inner class) or imported from a different module.
1 parent 82d9477 commit 454989f

File tree

4 files changed

+74
-7
lines changed

4 files changed

+74
-7
lines changed

mypyc/irbuild/classdef.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ def add_non_ext_class_attr_ann(
634634
if builder.current_module == type_info.module_name and stmt.line < type_info.line:
635635
typ = builder.load_str(type_info.fullname)
636636
else:
637-
typ = load_type(builder, type_info, stmt.line)
637+
typ = load_type(builder, type_info, stmt.unanalyzed_type, stmt.line)
638638

639639
if typ is None:
640640
# FIXME: if get_type_info is not provided, don't fall back to stmt.type?
@@ -650,7 +650,7 @@ def add_non_ext_class_attr_ann(
650650
# actually a forward reference due to the __annotations__ future?
651651
typ = builder.load_str(stmt.unanalyzed_type.original_str_expr)
652652
elif isinstance(ann_type, Instance):
653-
typ = load_type(builder, ann_type.type, stmt.line)
653+
typ = load_type(builder, ann_type.type, stmt.unanalyzed_type, stmt.line)
654654
else:
655655
typ = builder.add(LoadAddress(type_object_op.type, type_object_op.src, stmt.line))
656656

mypyc/irbuild/function.py

+39-5
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
TypeInfo,
3030
Var,
3131
)
32-
from mypy.types import CallableType, get_proper_type
32+
from mypy.types import CallableType, Type, UnboundType, get_proper_type
3333
from mypyc.common import LAMBDA_NAME, PROPSET_PREFIX, SELF_NAME
3434
from mypyc.ir.class_ir import ClassIR, NonExtClassInfo
3535
from mypyc.ir.func_ir import (
@@ -802,15 +802,49 @@ def get_func_target(builder: IRBuilder, fdef: FuncDef) -> AssignmentTarget:
802802
return builder.add_local_reg(fdef, object_rprimitive)
803803

804804

805-
def load_type(builder: IRBuilder, typ: TypeInfo, line: int) -> Value:
805+
# This function still does not support the following imports.
806+
# import json as _json
807+
# from json import decoder
808+
# Using either _json.JSONDecoder or decoder.JSONDecoder as a type hint for a dataclass field will fail.
809+
# See issue mypyc/mypyc#1099.
810+
def load_type(builder: IRBuilder, typ: TypeInfo, unbounded_type: Type | None, line: int) -> Value:
811+
# typ.fullname contains the module where the class object was defined. However, it is possible
812+
# that the class object's module was not imported in the file currently being compiled. So, we
813+
# use unbounded_type.name (if provided by caller) to load the class object through one of the
814+
# imported modules.
815+
# Example: for `json.JSONDecoder`, typ.fullname is `json.decoder.JSONDecoder` but the Python
816+
# file may import `json` not `json.decoder`.
817+
# Another corner case: The Python file being compiled imports mod1 and has a type hint
818+
# `mod1.OuterClass.InnerClass`. But, mod1/__init__.py might import OuterClass like this:
819+
# `from mod2.mod3 import OuterClass`. In this case, typ.fullname is
820+
# `mod2.mod3.OuterClass.InnerClass` and `unbounded_type.name` is `mod1.OuterClass.InnerClass`.
821+
# So, we must use unbounded_type.name to load the class object.
822+
# See issue mypyc/mypyc#1087.
823+
load_attr_path = (
824+
unbounded_type.name if isinstance(unbounded_type, UnboundType) else typ.fullname
825+
).removesuffix(f".{typ.name}")
806826
if typ in builder.mapper.type_to_ir:
807827
class_ir = builder.mapper.type_to_ir[typ]
808828
class_obj = builder.builder.get_native_type(class_ir)
809829
elif typ.fullname in builtin_names:
810830
builtin_addr_type, src = builtin_names[typ.fullname]
811831
class_obj = builder.add(LoadAddress(builtin_addr_type, src, line))
812-
elif typ.module_name in builder.imports:
813-
loaded_module = builder.load_module(typ.module_name)
832+
# This elif-condition finds the longest import that matches the load_attr_path.
833+
elif module_name := max(
834+
(i for i in builder.imports if load_attr_path == i or load_attr_path.startswith(f"{i}.")),
835+
default="",
836+
key=len,
837+
):
838+
# Load the imported module.
839+
loaded_module = builder.load_module(module_name)
840+
# Recursively load attributes of the imported module. These may be submodules, classes or
841+
# any other object.
842+
for attr in (
843+
load_attr_path.removeprefix(f"{module_name}.").split(".")
844+
if load_attr_path != module_name
845+
else []
846+
):
847+
loaded_module = builder.py_get_attr(loaded_module, attr, line)
814848
class_obj = builder.builder.get_attr(
815849
loaded_module, typ.name, object_rprimitive, line, borrow=False
816850
)
@@ -1039,7 +1073,7 @@ def maybe_insert_into_registry_dict(builder: IRBuilder, fitem: FuncDef) -> None:
10391073
)
10401074
registry = load_singledispatch_registry(builder, dispatch_func_obj, line)
10411075
for typ in types:
1042-
loaded_type = load_type(builder, typ, line)
1076+
loaded_type = load_type(builder, typ, None, line)
10431077
builder.primitive_op(dict_set_item_op, [registry, loaded_type, to_insert], line)
10441078
dispatch_cache = builder.builder.get_attr(
10451079
dispatch_func_obj, "dispatch_cache", dict_rprimitive, line

mypyc/test-data/commandline.test

+28
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,31 @@ print("imported foo")
261261
importing...
262262
imported foo
263263
done
264+
265+
[case testImportFromInitPy]
266+
# cmd: foo.py
267+
import foo
268+
269+
[file pkg2/__init__.py]
270+
271+
[file pkg2/mod2.py]
272+
class A:
273+
class B:
274+
pass
275+
276+
[file pkg1/__init__.py]
277+
from pkg2.mod2 import A
278+
279+
[file foo.py]
280+
import pkg1
281+
from typing import TypedDict
282+
283+
class Eggs(TypedDict):
284+
obj1: pkg1.A.B
285+
286+
print(type(Eggs(obj1=pkg1.A.B())["obj1"]).__name__)
287+
print(type(Eggs(obj1=pkg1.A.B())["obj1"]).__module__)
288+
289+
[out]
290+
B
291+
pkg2.mod2

mypyc/test-data/run-classes.test

+5
Original file line numberDiff line numberDiff line change
@@ -78,17 +78,22 @@ assert hasattr(c, 'x')
7878

7979
[case testTypedDictWithFields]
8080
import collections
81+
import json
8182
from typing import TypedDict
8283
class C(TypedDict):
8384
x: collections.deque
85+
spam: json.JSONDecoder
8486
[file driver.py]
8587
from native import C
8688
from collections import deque
89+
from json import JSONDecoder
8790

8891
print(C.__annotations__["x"] is deque)
92+
print(C.__annotations__["spam"] is JSONDecoder)
8993
[typing fixtures/typing-full.pyi]
9094
[out]
9195
True
96+
True
9297

9398
[case testClassWithDeletableAttributes]
9499
from typing import Any, cast

0 commit comments

Comments
 (0)