Skip to content

Commit 963e4ae

Browse files
hugues-aff authored and JukkaL committed
more enum-related speedups (#12032)
As a follow-up to #9394, address a few more O(n**2) behaviors caused by decomposing enums into unions of literals.
1 parent 914506b commit 963e4ae

File tree

4 files changed

+116
-16
lines changed

4 files changed

+116
-16
lines changed

mypy/meet.py

+30
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
6464
if isinstance(declared, UnionType):
6565
return make_simplified_union([narrow_declared_type(x, narrowed)
6666
for x in declared.relevant_items()])
67+
if is_enum_overlapping_union(declared, narrowed):
68+
return narrowed
6769
elif not is_overlapping_types(declared, narrowed,
6870
prohibit_none_typevar_overlap=True):
6971
if state.strict_optional:
@@ -137,6 +139,22 @@ def get_possible_variants(typ: Type) -> List[Type]:
137139
return [typ]
138140

139141

142+
def is_enum_overlapping_union(x: ProperType, y: ProperType) -> bool:
    """Check whether x is an enum Instance and y is a Union holding a Literal of that enum.

    Only one matching Literal item is required; the remaining union items
    are not inspected once a match is found.
    """
    # Guard clauses: x must be an enum instance and y must be a union.
    if not isinstance(x, Instance) or not x.type.is_enum:
        return False
    if not isinstance(y, UnionType):
        return False

    enum_info = x.type
    for member in y.relevant_items():
        candidate = get_proper_type(member)
        # A literal belongs to the enum when its fallback has the same type info.
        if isinstance(candidate, LiteralType) and enum_info == candidate.fallback.type:
            return True
    return False
150+
151+
152+
def is_literal_in_union(x: ProperType, y: ProperType) -> bool:
    """Check whether x is a Literal and y is a Union with an item equal to x."""
    if not (isinstance(x, LiteralType) and isinstance(y, UnionType)):
        return False

    # Compare against every union item after resolving it to a proper type.
    for member in y.items:
        if x == get_proper_type(member):
            return True
    return False
156+
157+
140158
def is_overlapping_types(left: Type,
141159
right: Type,
142160
ignore_promotions: bool = False,
@@ -198,6 +216,18 @@ def _is_overlapping_types(left: Type, right: Type) -> bool:
198216
#
199217
# These checks will also handle the NoneType and UninhabitedType cases for us.
200218

219+
# enums are sometimes expanded into a Union of Literals
220+
# when that happens we want to make sure we treat the two as overlapping
221+
# and crucially, we want to do that *fast* in case the enum is large
222+
# so we do it before expanding variants below to avoid O(n**2) behavior
223+
if (
224+
is_enum_overlapping_union(left, right)
225+
or is_enum_overlapping_union(right, left)
226+
or is_literal_in_union(left, right)
227+
or is_literal_in_union(right, left)
228+
):
229+
return True
230+
201231
if (is_proper_subtype(left, right, ignore_promotions=ignore_promotions)
202232
or is_proper_subtype(right, left, ignore_promotions=ignore_promotions)):
203233
return True

mypy/sametypes.py

+28-6
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Sequence
1+
from typing import Sequence, Tuple, Set, List
22

33
from mypy.types import (
44
Type, UnboundType, AnyType, NoneType, TupleType, TypedDictType,
55
UnionType, CallableType, TypeVarType, Instance, TypeVisitor, ErasedType,
66
Overloaded, PartialType, DeletedType, UninhabitedType, TypeType, LiteralType,
77
ProperType, get_proper_type, TypeAliasType, ParamSpecType, Parameters, UnpackType
88
)
9-
from mypy.typeops import tuple_fallback, make_simplified_union
9+
from mypy.typeops import tuple_fallback, make_simplified_union, is_simple_literal
1010

1111

1212
def is_same_type(left: Type, right: Type) -> bool:
@@ -49,6 +49,22 @@ def is_same_types(a1: Sequence[Type], a2: Sequence[Type]) -> bool:
4949
return True
5050

5151

52+
def _extract_literals(u: UnionType) -> Tuple[Set[Type], List[Type]]:
    """Partition the relevant items of a union into simple literals and everything else.

    Returns a pair ``(literals, others)``: ``literals`` is a set of the items
    accepted by ``is_simple_literal``, and ``others`` is a list of the
    remaining items in their original order. Every item is resolved with
    ``get_proper_type`` before being classified and stored.

    Keeping the literals in a set lets callers compare them with set
    operations instead of pairwise checks, which matters for the large
    unions that can result from large enums.
    """
    literals: Set[Type] = set()
    others: List[Type] = []
    for raw_item in u.relevant_items():
        proper_item = get_proper_type(raw_item)
        if not is_simple_literal(proper_item):
            others.append(proper_item)
            continue
        literals.add(proper_item)
    return literals, others
66+
67+
5268
class SameTypeVisitor(TypeVisitor[bool]):
5369
"""Visitor for checking whether two types are the 'same' type."""
5470

@@ -153,14 +169,20 @@ def visit_literal_type(self, left: LiteralType) -> bool:
153169

154170
def visit_union_type(self, left: UnionType) -> bool:
155171
if isinstance(self.right, UnionType):
172+
left_lit, left_rem = _extract_literals(left)
173+
right_lit, right_rem = _extract_literals(self.right)
174+
175+
if left_lit != right_lit:
176+
return False
177+
156178
# Check that everything in left is in right
157-
for left_item in left.items:
158-
if not any(is_same_type(left_item, right_item) for right_item in self.right.items):
179+
for left_item in left_rem:
180+
if not any(is_same_type(left_item, right_item) for right_item in right_rem):
159181
return False
160182

161183
# Check that everything in right is in left
162-
for right_item in self.right.items:
163-
if not any(is_same_type(right_item, left_item) for left_item in left.items):
184+
for right_item in right_rem:
185+
if not any(is_same_type(right_item, left_item) for left_item in left_rem):
164186
return False
165187

166188
return True

mypy/subtypes.py

+49-10
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,20 @@ def visit_overloaded(self, left: Overloaded) -> bool:
535535
return False
536536

537537
def visit_union_type(self, left: UnionType) -> bool:
    """Return True if every item of the union ``left`` is a subtype of the right type.

    When the right-hand side is an Instance, simple literal items are
    replaced by their fallback Instance before checking, and each distinct
    fallback is checked only once. This avoids repeating the same check
    for every literal in a union of literals.
    """
    # General case: check every item against the right-hand side.
    if not isinstance(self.right, Instance):
        return all(self._is_subtype(item, self.orig_right) for item in left.items)

    seen_fallbacks: Set[Instance] = set()
    for member in left.relevant_items():
        candidate = get_proper_type(member)
        fallback = mypy.typeops.simple_literal_type(candidate)
        if fallback is not None:
            if fallback in seen_fallbacks:
                # This fallback was already checked for an earlier literal.
                continue
            seen_fallbacks.add(fallback)
            candidate = fallback
        if not self._is_subtype(candidate, self.orig_right):
            return False
    return True
539553

540554
def visit_partial_type(self, left: PartialType) -> bool:
@@ -1199,6 +1213,27 @@ def report(*args: Any) -> None:
11991213
return applied
12001214

12011215

1216+
def try_restrict_literal_union(t: UnionType, s: Type) -> Optional[List[Type]]:
    """Remove occurrences of the simple literal s from a union of simple literals.

    Returns the relevant items of t that are not equal to s, in their
    original order and as the original (unresolved) items.

    Returns None when s is not a simple literal, or when any relevant item
    of t is not a simple literal.
    """
    is_simple = mypy.typeops.is_simple_literal

    target = get_proper_type(s)
    if not is_simple(target):
        return None

    kept: List[Type] = []
    for member in t.relevant_items():
        proper_member = get_proper_type(member)
        if not is_simple(proper_member):
            # A single non-literal item disqualifies the whole union.
            return None
        if proper_member != target:
            kept.append(member)
    return kept
1235+
1236+
12021237
def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False) -> Type:
12031238
"""Return t minus s for runtime type assertions.
12041239
@@ -1212,10 +1247,14 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False)
12121247
s = get_proper_type(s)
12131248

12141249
if isinstance(t, UnionType):
1215-
new_items = [restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
1216-
for item in t.relevant_items()
1217-
if (isinstance(get_proper_type(item), AnyType) or
1218-
not covers_at_runtime(item, s, ignore_promotions))]
1250+
new_items = try_restrict_literal_union(t, s)
1251+
if new_items is None:
1252+
new_items = [
1253+
restrict_subtype_away(item, s, ignore_promotions=ignore_promotions)
1254+
for item in t.relevant_items()
1255+
if (isinstance(get_proper_type(item), AnyType) or
1256+
not covers_at_runtime(item, s, ignore_promotions))
1257+
]
12191258
return UnionType.make_union(new_items)
12201259
elif covers_at_runtime(t, s, ignore_promotions):
12211260
return UninhabitedType()
@@ -1285,11 +1324,11 @@ def _is_proper_subtype(left: Type, right: Type, *,
12851324
right = get_proper_type(right)
12861325

12871326
if isinstance(right, UnionType) and not isinstance(left, UnionType):
1288-
return any([is_proper_subtype(orig_left, item,
1289-
ignore_promotions=ignore_promotions,
1290-
erase_instances=erase_instances,
1291-
keep_erased_types=keep_erased_types)
1292-
for item in right.items])
1327+
return any(is_proper_subtype(orig_left, item,
1328+
ignore_promotions=ignore_promotions,
1329+
erase_instances=erase_instances,
1330+
keep_erased_types=keep_erased_types)
1331+
for item in right.items)
12931332
return left.accept(ProperSubtypeVisitor(orig_right,
12941333
ignore_promotions=ignore_promotions,
12951334
erase_instances=erase_instances,
@@ -1495,7 +1534,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
14951534
return False
14961535

14971536
def visit_union_type(self, left: UnionType) -> bool:
1498-
return all([self._is_proper_subtype(item, self.orig_right) for item in left.items])
1537+
return all(self._is_proper_subtype(item, self.orig_right) for item in left.items)
14991538

15001539
def visit_partial_type(self, left: PartialType) -> bool:
15011540
# TODO: What's the right thing to do here?

mypy/typeops.py

+9
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,15 @@ def simple_literal_value_key(t: ProperType) -> Optional[Tuple[str, ...]]:
318318
return None
319319

320320

321+
def simple_literal_type(t: ProperType) -> Optional[Instance]:
    """Return the fallback Instance behind a simple Literal, or None.

    An Instance with a non-None ``last_known_value`` is treated as that
    value. Anything that is not a LiteralType after this step gives None.
    """
    candidate = t
    if isinstance(candidate, Instance) and candidate.last_known_value is not None:
        candidate = candidate.last_known_value
    return candidate.fallback if isinstance(candidate, LiteralType) else None
328+
329+
321330
def is_simple_literal(t: ProperType) -> bool:
322331
"""Fast way to check if simple_literal_value_key() would return a non-None value."""
323332
if isinstance(t, LiteralType):

0 commit comments

Comments
 (0)