@@ -4307,7 +4307,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
4307
4307
with self .binder .frame_context (can_skip = True , fall_through = 4 ):
4308
4308
typ = s .types [i ]
4309
4309
if typ :
4310
- t = self .check_except_handler_test (typ )
4310
+ t = self .check_except_handler_test (typ , s . is_star )
4311
4311
var = s .vars [i ]
4312
4312
if var :
4313
4313
# To support local variables, we make this a definition line,
@@ -4327,7 +4327,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
4327
4327
if s .else_body :
4328
4328
self .accept (s .else_body )
4329
4329
4330
- def check_except_handler_test (self , n : Expression ) -> Type :
4330
+ def check_except_handler_test (self , n : Expression , is_star : bool ) -> Type :
4331
4331
"""Type check an exception handler test clause."""
4332
4332
typ = self .expr_checker .accept (n )
4333
4333
@@ -4343,22 +4343,47 @@ def check_except_handler_test(self, n: Expression) -> Type:
4343
4343
item = ttype .items [0 ]
4344
4344
if not item .is_type_obj ():
4345
4345
self .fail (message_registry .INVALID_EXCEPTION_TYPE , n )
4346
- return AnyType ( TypeOfAny . from_error )
4347
- exc_type = item .ret_type
4346
+ return self . default_exception_type ( is_star )
4347
+ exc_type = erase_typevars ( item .ret_type )
4348
4348
elif isinstance (ttype , TypeType ):
4349
4349
exc_type = ttype .item
4350
4350
else :
4351
4351
self .fail (message_registry .INVALID_EXCEPTION_TYPE , n )
4352
- return AnyType ( TypeOfAny . from_error )
4352
+ return self . default_exception_type ( is_star )
4353
4353
4354
4354
if not is_subtype (exc_type , self .named_type ("builtins.BaseException" )):
4355
4355
self .fail (message_registry .INVALID_EXCEPTION_TYPE , n )
4356
- return AnyType ( TypeOfAny . from_error )
4356
+ return self . default_exception_type ( is_star )
4357
4357
4358
4358
all_types .append (exc_type )
4359
4359
4360
+ if is_star :
4361
+ new_all_types : list [Type ] = []
4362
+ for typ in all_types :
4363
+ if is_proper_subtype (typ , self .named_type ("builtins.BaseExceptionGroup" )):
4364
+ self .fail (message_registry .INVALID_EXCEPTION_GROUP , n )
4365
+ new_all_types .append (AnyType (TypeOfAny .from_error ))
4366
+ else :
4367
+ new_all_types .append (typ )
4368
+ return self .wrap_exception_group (new_all_types )
4360
4369
return make_simplified_union (all_types )
4361
4370
4371
+ def default_exception_type (self , is_star : bool ) -> Type :
4372
+ """Exception type to return in case of a previous type error."""
4373
+ any_type = AnyType (TypeOfAny .from_error )
4374
+ if is_star :
4375
+ return self .named_generic_type ("builtins.ExceptionGroup" , [any_type ])
4376
+ return any_type
4377
+
4378
+ def wrap_exception_group (self , types : Sequence [Type ]) -> Type :
4379
+ """Transform except* variable type into an appropriate exception group."""
4380
+ arg = make_simplified_union (types )
4381
+ if is_subtype (arg , self .named_type ("builtins.Exception" )):
4382
+ base = "builtins.ExceptionGroup"
4383
+ else :
4384
+ base = "builtins.BaseExceptionGroup"
4385
+ return self .named_generic_type (base , [arg ])
4386
+
4362
4387
def get_types_from_except_handler (self , typ : Type , n : Expression ) -> list [Type ]:
4363
4388
"""Helper for check_except_handler_test to retrieve handler types."""
4364
4389
typ = get_proper_type (typ )
0 commit comments