Skip to content

Commit 9944d5f

Browse files
authored
[mypyc] Support iterating over a TypedDict (#14747)
An optimization to make iterating over dict.keys(), dict.values() and dict.items() faster caused mypyc to crash while compiling a TypedDict. This commit fixes `Builder.get_dict_base_type` to properly handle `TypedDictType`. Fixes mypyc/mypyc#869.
1 parent 1a8ea61 commit 9944d5f

File tree

4 files changed

+112
-5
lines changed

4 files changed

+112
-5
lines changed

Diff for: mypyc/irbuild/builder.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
ProperType,
5353
TupleType,
5454
Type,
55+
TypedDictType,
5556
TypeOfAny,
5657
UninhabitedType,
5758
UnionType,
@@ -913,8 +914,12 @@ def get_dict_base_type(self, expr: Expression) -> list[Instance]:
913914

914915
dict_types = []
915916
for t in types:
916-
assert isinstance(t, Instance), t
917-
dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict")
917+
if isinstance(t, TypedDictType):
918+
t = t.fallback
919+
dict_base = next(base for base in t.type.mro if base.fullname == "typing.Mapping")
920+
else:
921+
assert isinstance(t, Instance), t
922+
dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict")
918923
dict_types.append(map_instance_to_supertype(t, dict_base))
919924
return dict_types
920925

Diff for: mypyc/test-data/irbuild-dict.test

+69
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,12 @@ L0:
219219

220220
[case testDictIterationMethods]
221221
from typing import Dict, Union
222+
from typing_extensions import TypedDict
223+
224+
class Person(TypedDict):
225+
name: str
226+
age: int
227+
222228
def print_dict_methods(d1: Dict[int, int], d2: Dict[int, int]) -> None:
223229
for v in d1.values():
224230
if v in d2:
@@ -229,6 +235,10 @@ def union_of_dicts(d: Union[Dict[str, int], Dict[str, str]]) -> None:
229235
new = {}
230236
for k, v in d.items():
231237
new[k] = int(v)
238+
def typeddict(d: Person) -> None:
239+
for k, v in d.items():
240+
if k == "name":
241+
name = v
232242
[out]
233243
def print_dict_methods(d1, d2):
234244
d1, d2 :: dict
@@ -370,6 +380,65 @@ L4:
370380
r19 = CPy_NoErrOccured()
371381
L5:
372382
return 1
383+
def typeddict(d):
384+
d :: dict
385+
r0 :: short_int
386+
r1 :: native_int
387+
r2 :: short_int
388+
r3 :: object
389+
r4 :: tuple[bool, short_int, object, object]
390+
r5 :: short_int
391+
r6 :: bool
392+
r7, r8 :: object
393+
r9, k :: str
394+
v :: object
395+
r10 :: str
396+
r11 :: int32
397+
r12 :: bit
398+
r13 :: object
399+
r14, r15, r16 :: bit
400+
name :: object
401+
r17, r18 :: bit
402+
L0:
403+
r0 = 0
404+
r1 = PyDict_Size(d)
405+
r2 = r1 << 1
406+
r3 = CPyDict_GetItemsIter(d)
407+
L1:
408+
r4 = CPyDict_NextItem(r3, r0)
409+
r5 = r4[1]
410+
r0 = r5
411+
r6 = r4[0]
412+
if r6 goto L2 else goto L9 :: bool
413+
L2:
414+
r7 = r4[2]
415+
r8 = r4[3]
416+
r9 = cast(str, r7)
417+
k = r9
418+
v = r8
419+
r10 = 'name'
420+
r11 = PyUnicode_Compare(k, r10)
421+
r12 = r11 == -1
422+
if r12 goto L3 else goto L5 :: bool
423+
L3:
424+
r13 = PyErr_Occurred()
425+
r14 = r13 != 0
426+
if r14 goto L4 else goto L5 :: bool
427+
L4:
428+
r15 = CPy_KeepPropagating()
429+
L5:
430+
r16 = r11 == 0
431+
if r16 goto L6 else goto L7 :: bool
432+
L6:
433+
name = v
434+
L7:
435+
L8:
436+
r17 = CPyDict_CheckSize(d, r2)
437+
goto L1
438+
L9:
439+
r18 = CPy_NoErrOccured()
440+
L10:
441+
return 1
373442

374443
[case testDictLoadAddress]
375444
def f() -> None:

Diff for: mypyc/test-data/run-dicts.test

+32-2
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,13 @@ assert get_content_set(od) == ({1, 3}, {2, 4}, {(1, 2), (3, 4)})
9595
[typing fixtures/typing-full.pyi]
9696

9797
[case testDictIterationMethodsRun]
98-
from typing import Dict
98+
from typing import Dict, Union
99+
from typing_extensions import TypedDict
100+
101+
class ExtensionDict(TypedDict):
102+
python: str
103+
c: str
104+
99105
def print_dict_methods(d1: Dict[int, int],
100106
d2: Dict[int, int],
101107
d3: Dict[int, int]) -> None:
@@ -107,13 +113,27 @@ def print_dict_methods(d1: Dict[int, int],
107113
for v in d3.values():
108114
print(v)
109115

116+
def print_dict_methods_special(d1: Union[Dict[int, int], Dict[str, str]],
117+
d2: ExtensionDict) -> None:
118+
for k in d1.keys():
119+
print(k)
120+
for k, v in d1.items():
121+
print(k)
122+
print(v)
123+
for v2 in d2.values():
124+
print(v2)
125+
for k2, v2 in d2.items():
126+
print(k2)
127+
print(v2)
128+
129+
110130
def clear_during_iter(d: Dict[int, int]) -> None:
111131
for k in d:
112132
d.clear()
113133

114134
class Custom(Dict[int, int]): pass
115135
[file driver.py]
116-
from native import print_dict_methods, Custom, clear_during_iter
136+
from native import print_dict_methods, print_dict_methods_special, Custom, clear_during_iter
117137
from collections import OrderedDict
118138
print_dict_methods({}, {}, {})
119139
print_dict_methods({1: 2}, {3: 4, 5: 6}, {7: 8})
@@ -124,6 +144,7 @@ print('==')
124144
d = OrderedDict([(1, 2), (3, 4)])
125145
print_dict_methods(d, d, d)
126146
print('==')
147+
print_dict_methods_special({1: 2}, {"python": ".py", "c": ".c"})
127148
d.move_to_end(1)
128149
print_dict_methods(d, d, d)
129150
clear_during_iter({}) # OK
@@ -185,6 +206,15 @@ else:
185206
2
186207
4
187208
==
209+
1
210+
1
211+
2
212+
.py
213+
.c
214+
python
215+
.py
216+
c
217+
.c
188218
3
189219
1
190220
3

Diff for: test-data/unit/lib-stub/typing_extensions.pyi

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import typing
2-
from typing import Any, Mapping, Iterator, NoReturn as NoReturn, Dict, Type
2+
from typing import Any, Mapping, Iterable, Iterator, NoReturn as NoReturn, Dict, Tuple, Type
33
from typing import TYPE_CHECKING as TYPE_CHECKING
44
from typing import NewType as NewType, overload as overload
55

@@ -50,6 +50,9 @@ class _TypedDict(Mapping[str, object]):
5050
# Mypy expects that 'default' has a type variable type.
5151
def pop(self, k: NoReturn, default: _T = ...) -> object: ...
5252
def update(self: _T, __m: _T) -> None: ...
53+
def items(self) -> Iterable[Tuple[str, object]]: ...
54+
def keys(self) -> Iterable[str]: ...
55+
def values(self) -> Iterable[object]: ...
5356
if sys.version_info < (3, 0):
5457
def has_key(self, k: str) -> bool: ...
5558
def __delitem__(self, k: NoReturn) -> None: ...

0 commit comments

Comments
 (0)