@@ -535,6 +535,20 @@ def visit_overloaded(self, left: Overloaded) -> bool:
535
535
return False
536
536
537
537
def visit_union_type (self , left : UnionType ) -> bool :
538
+ if isinstance (self .right , Instance ):
539
+ literal_types : Set [Instance ] = set ()
540
+ # avoid redundant check for union of literals
541
+ for item in left .relevant_items ():
542
+ item = get_proper_type (item )
543
+ lit_type = mypy .typeops .simple_literal_type (item )
544
+ if lit_type is not None :
545
+ if lit_type in literal_types :
546
+ continue
547
+ literal_types .add (lit_type )
548
+ item = lit_type
549
+ if not self ._is_subtype (item , self .orig_right ):
550
+ return False
551
+ return True
538
552
return all (self ._is_subtype (item , self .orig_right ) for item in left .items )
539
553
540
554
def visit_partial_type (self , left : PartialType ) -> bool :
@@ -1199,6 +1213,27 @@ def report(*args: Any) -> None:
1199
1213
return applied
1200
1214
1201
1215
1216
+ def try_restrict_literal_union (t : UnionType , s : Type ) -> Optional [List [Type ]]:
1217
+ """Return the items of t, excluding any occurrence of s, if and only if
1218
+ - t only contains simple literals
1219
+ - s is a simple literal
1220
+
1221
+ Otherwise, returns None
1222
+ """
1223
+ ps = get_proper_type (s )
1224
+ if not mypy .typeops .is_simple_literal (ps ):
1225
+ return None
1226
+
1227
+ new_items : List [Type ] = []
1228
+ for i in t .relevant_items ():
1229
+ pi = get_proper_type (i )
1230
+ if not mypy .typeops .is_simple_literal (pi ):
1231
+ return None
1232
+ if pi != ps :
1233
+ new_items .append (i )
1234
+ return new_items
1235
+
1236
+
1202
1237
def restrict_subtype_away (t : Type , s : Type , * , ignore_promotions : bool = False ) -> Type :
1203
1238
"""Return t minus s for runtime type assertions.
1204
1239
@@ -1212,10 +1247,14 @@ def restrict_subtype_away(t: Type, s: Type, *, ignore_promotions: bool = False)
1212
1247
s = get_proper_type (s )
1213
1248
1214
1249
if isinstance (t , UnionType ):
1215
- new_items = [restrict_subtype_away (item , s , ignore_promotions = ignore_promotions )
1216
- for item in t .relevant_items ()
1217
- if (isinstance (get_proper_type (item ), AnyType ) or
1218
- not covers_at_runtime (item , s , ignore_promotions ))]
1250
+ new_items = try_restrict_literal_union (t , s )
1251
+ if new_items is None :
1252
+ new_items = [
1253
+ restrict_subtype_away (item , s , ignore_promotions = ignore_promotions )
1254
+ for item in t .relevant_items ()
1255
+ if (isinstance (get_proper_type (item ), AnyType ) or
1256
+ not covers_at_runtime (item , s , ignore_promotions ))
1257
+ ]
1219
1258
return UnionType .make_union (new_items )
1220
1259
elif covers_at_runtime (t , s , ignore_promotions ):
1221
1260
return UninhabitedType ()
@@ -1285,11 +1324,11 @@ def _is_proper_subtype(left: Type, right: Type, *,
1285
1324
right = get_proper_type (right )
1286
1325
1287
1326
if isinstance (right , UnionType ) and not isinstance (left , UnionType ):
1288
- return any ([ is_proper_subtype (orig_left , item ,
1289
- ignore_promotions = ignore_promotions ,
1290
- erase_instances = erase_instances ,
1291
- keep_erased_types = keep_erased_types )
1292
- for item in right .items ] )
1327
+ return any (is_proper_subtype (orig_left , item ,
1328
+ ignore_promotions = ignore_promotions ,
1329
+ erase_instances = erase_instances ,
1330
+ keep_erased_types = keep_erased_types )
1331
+ for item in right .items )
1293
1332
return left .accept (ProperSubtypeVisitor (orig_right ,
1294
1333
ignore_promotions = ignore_promotions ,
1295
1334
erase_instances = erase_instances ,
@@ -1495,7 +1534,7 @@ def visit_overloaded(self, left: Overloaded) -> bool:
1495
1534
return False
1496
1535
1497
1536
def visit_union_type (self , left : UnionType ) -> bool :
1498
- return all ([ self ._is_proper_subtype (item , self .orig_right ) for item in left .items ] )
1537
+ return all (self ._is_proper_subtype (item , self .orig_right ) for item in left .items )
1499
1538
1500
1539
def visit_partial_type (self , left : PartialType ) -> bool :
1501
1540
# TODO: What's the right thing to do here?
0 commit comments