Skip to content

Commit 719cef9

Browse files
ilevkivskyisvalentin
authored andcommitted
Add support for exception groups and except* (#14020)
Ref #12840 It looks like from the point of view of type checking support is quite easy. Mypyc support however requires some actual work, so I don't include it in this PR.
1 parent 91b6fc3 commit 719cef9

File tree

9 files changed

+103
-14
lines changed

9 files changed

+103
-14
lines changed

mypy/checker.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -4307,7 +4307,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
43074307
with self.binder.frame_context(can_skip=True, fall_through=4):
43084308
typ = s.types[i]
43094309
if typ:
4310-
t = self.check_except_handler_test(typ)
4310+
t = self.check_except_handler_test(typ, s.is_star)
43114311
var = s.vars[i]
43124312
if var:
43134313
# To support local variables, we make this a definition line,
@@ -4327,7 +4327,7 @@ def visit_try_without_finally(self, s: TryStmt, try_frame: bool) -> None:
43274327
if s.else_body:
43284328
self.accept(s.else_body)
43294329

4330-
def check_except_handler_test(self, n: Expression) -> Type:
4330+
def check_except_handler_test(self, n: Expression, is_star: bool) -> Type:
43314331
"""Type check an exception handler test clause."""
43324332
typ = self.expr_checker.accept(n)
43334333

@@ -4343,22 +4343,47 @@ def check_except_handler_test(self, n: Expression) -> Type:
43434343
item = ttype.items[0]
43444344
if not item.is_type_obj():
43454345
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
4346-
return AnyType(TypeOfAny.from_error)
4347-
exc_type = item.ret_type
4346+
return self.default_exception_type(is_star)
4347+
exc_type = erase_typevars(item.ret_type)
43484348
elif isinstance(ttype, TypeType):
43494349
exc_type = ttype.item
43504350
else:
43514351
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
4352-
return AnyType(TypeOfAny.from_error)
4352+
return self.default_exception_type(is_star)
43534353

43544354
if not is_subtype(exc_type, self.named_type("builtins.BaseException")):
43554355
self.fail(message_registry.INVALID_EXCEPTION_TYPE, n)
4356-
return AnyType(TypeOfAny.from_error)
4356+
return self.default_exception_type(is_star)
43574357

43584358
all_types.append(exc_type)
43594359

4360+
if is_star:
4361+
new_all_types: list[Type] = []
4362+
for typ in all_types:
4363+
if is_proper_subtype(typ, self.named_type("builtins.BaseExceptionGroup")):
4364+
self.fail(message_registry.INVALID_EXCEPTION_GROUP, n)
4365+
new_all_types.append(AnyType(TypeOfAny.from_error))
4366+
else:
4367+
new_all_types.append(typ)
4368+
return self.wrap_exception_group(new_all_types)
43604369
return make_simplified_union(all_types)
43614370

4371+
def default_exception_type(self, is_star: bool) -> Type:
4372+
"""Exception type to return in case of a previous type error."""
4373+
any_type = AnyType(TypeOfAny.from_error)
4374+
if is_star:
4375+
return self.named_generic_type("builtins.ExceptionGroup", [any_type])
4376+
return any_type
4377+
4378+
def wrap_exception_group(self, types: Sequence[Type]) -> Type:
4379+
"""Transform except* variable type into an appropriate exception group."""
4380+
arg = make_simplified_union(types)
4381+
if is_subtype(arg, self.named_type("builtins.Exception")):
4382+
base = "builtins.ExceptionGroup"
4383+
else:
4384+
base = "builtins.BaseExceptionGroup"
4385+
return self.named_generic_type(base, [arg])
4386+
43624387
def get_types_from_except_handler(self, typ: Type, n: Expression) -> list[Type]:
43634388
"""Helper for check_except_handler_test to retrieve handler types."""
43644389
typ = get_proper_type(typ)

mypy/fastparse.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1254,7 +1254,6 @@ def visit_Try(self, n: ast3.Try) -> TryStmt:
12541254
return self.set_line(node, n)
12551255

12561256
def visit_TryStar(self, n: TryStar) -> TryStmt:
1257-
# TODO: we treat TryStar exactly like Try, which makes mypy not crash. See #12840
12581257
vs = [
12591258
self.set_line(NameExpr(h.name), h) if h.name is not None else None for h in n.handlers
12601259
]
@@ -1269,6 +1268,7 @@ def visit_TryStar(self, n: TryStar) -> TryStmt:
12691268
self.as_block(n.orelse, n.lineno),
12701269
self.as_block(n.finalbody, n.lineno),
12711270
)
1271+
node.is_star = True
12721272
return self.set_line(node, n)
12731273

12741274
# Assert(expr test, expr? msg)

mypy/message_registry.py

+3
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,9 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
4444
NO_RETURN_EXPECTED: Final = ErrorMessage("Return statement in function which does not return")
4545
INVALID_EXCEPTION: Final = ErrorMessage("Exception must be derived from BaseException")
4646
INVALID_EXCEPTION_TYPE: Final = ErrorMessage("Exception type must be derived from BaseException")
47+
INVALID_EXCEPTION_GROUP: Final = ErrorMessage(
48+
"Exception type in except* cannot derive from BaseExceptionGroup"
49+
)
4750
RETURN_IN_ASYNC_GENERATOR: Final = ErrorMessage(
4851
'"return" with value in async generator is not allowed'
4952
)

mypy/nodes.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1485,7 +1485,7 @@ def accept(self, visitor: StatementVisitor[T]) -> T:
14851485

14861486

14871487
class TryStmt(Statement):
1488-
__slots__ = ("body", "types", "vars", "handlers", "else_body", "finally_body")
1488+
__slots__ = ("body", "types", "vars", "handlers", "else_body", "finally_body", "is_star")
14891489

14901490
body: Block # Try body
14911491
# Plain 'except:' also possible
@@ -1494,6 +1494,8 @@ class TryStmt(Statement):
14941494
handlers: list[Block] # Except bodies
14951495
else_body: Block | None
14961496
finally_body: Block | None
1497+
# Whether this is try ... except* (added in Python 3.11)
1498+
is_star: bool
14971499

14981500
def __init__(
14991501
self,
@@ -1511,6 +1513,7 @@ def __init__(
15111513
self.handlers = handlers
15121514
self.else_body = else_body
15131515
self.finally_body = finally_body
1516+
self.is_star = False
15141517

15151518
def accept(self, visitor: StatementVisitor[T]) -> T:
15161519
return visitor.visit_try_stmt(self)

mypy/strconv.py

+2
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,8 @@ def visit_del_stmt(self, o: mypy.nodes.DelStmt) -> str:
276276

277277
def visit_try_stmt(self, o: mypy.nodes.TryStmt) -> str:
278278
a: list[Any] = [o.body]
279+
if o.is_star:
280+
a.append("*")
279281

280282
for i in range(len(o.vars)):
281283
a.append(o.types[i])

mypy/treetransform.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -373,14 +373,16 @@ def visit_raise_stmt(self, node: RaiseStmt) -> RaiseStmt:
373373
return RaiseStmt(self.optional_expr(node.expr), self.optional_expr(node.from_expr))
374374

375375
def visit_try_stmt(self, node: TryStmt) -> TryStmt:
376-
return TryStmt(
376+
new = TryStmt(
377377
self.block(node.body),
378378
self.optional_names(node.vars),
379379
self.optional_expressions(node.types),
380380
self.blocks(node.handlers),
381381
self.optional_block(node.else_body),
382382
self.optional_block(node.finally_body),
383383
)
384+
new.is_star = node.is_star
385+
return new
384386

385387
def visit_with_stmt(self, node: WithStmt) -> WithStmt:
386388
new = WithStmt(

mypyc/irbuild/statement.py

+2
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,8 @@ def transform_try_stmt(builder: IRBuilder, t: TryStmt) -> None:
616616
# constructs that we compile separately. When we have a
617617
# try/except/else/finally, we treat the try/except/else as the
618618
# body of a try/finally block.
619+
if t.is_star:
620+
builder.error("Exception groups and except* cannot be compiled yet", t.line)
619621
if t.finally_body:
620622

621623
def transform_try_body() -> None:

test-data/unit/check-python311.test

+49-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,53 @@
1-
[case testTryStarDoesNotCrash]
1+
[case testTryStarSimple]
22
try:
33
pass
44
except* Exception as e:
5-
reveal_type(e) # N: Revealed type is "builtins.Exception"
5+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.Exception]"
6+
[builtins fixtures/exception.pyi]
7+
8+
[case testTryStarMultiple]
9+
try:
10+
pass
11+
except* Exception as e:
12+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.Exception]"
13+
except* RuntimeError as e:
14+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[builtins.RuntimeError]"
15+
[builtins fixtures/exception.pyi]
16+
17+
[case testTryStarBase]
18+
try:
19+
pass
20+
except* BaseException as e:
21+
reveal_type(e) # N: Revealed type is "builtins.BaseExceptionGroup[builtins.BaseException]"
22+
[builtins fixtures/exception.pyi]
23+
24+
[case testTryStarTuple]
25+
class Custom(Exception): ...
26+
27+
try:
28+
pass
29+
except* (RuntimeError, Custom) as e:
30+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Union[builtins.RuntimeError, __main__.Custom]]"
31+
[builtins fixtures/exception.pyi]
32+
33+
[case testTryStarInvalidType]
34+
class Bad: ...
35+
try:
36+
pass
37+
except* (RuntimeError, Bad) as e: # E: Exception type must be derived from BaseException
38+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Any]"
39+
[builtins fixtures/exception.pyi]
40+
41+
[case testTryStarGroupInvalid]
42+
try:
43+
pass
44+
except* ExceptionGroup as e: # E: Exception type in except* cannot derive from BaseExceptionGroup
45+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Any]"
46+
[builtins fixtures/exception.pyi]
47+
48+
[case testTryStarGroupInvalidTuple]
49+
try:
50+
pass
51+
except* (RuntimeError, ExceptionGroup) as e: # E: Exception type in except* cannot derive from BaseExceptionGroup
52+
reveal_type(e) # N: Revealed type is "builtins.ExceptionGroup[Union[builtins.RuntimeError, Any]]"
653
[builtins fixtures/exception.pyi]

test-data/unit/fixtures/exception.pyi

+8-3
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,28 @@
1+
import sys
12
from typing import Generic, TypeVar
23
T = TypeVar('T')
34

45
class object:
56
def __init__(self): pass
67

78
class type: pass
8-
class tuple(Generic[T]): pass
9+
class tuple(Generic[T]):
10+
def __ge__(self, other: object) -> bool: ...
911
class function: pass
1012
class int: pass
1113
class str: pass
1214
class unicode: pass
1315
class bool: pass
1416
class ellipsis: pass
1517

16-
# Note: this is a slight simplification. In Python 2, the inheritance hierarchy
17-
# is actually Exception -> StandardError -> RuntimeError -> ...
1818
class BaseException:
1919
def __init__(self, *args: object) -> None: ...
2020
class Exception(BaseException): pass
2121
class RuntimeError(Exception): pass
2222
class NotImplementedError(RuntimeError): pass
2323

24+
if sys.version_info >= (3, 11):
25+
_BT_co = TypeVar("_BT_co", bound=BaseException, covariant=True)
26+
_T_co = TypeVar("_T_co", bound=Exception, covariant=True)
27+
class BaseExceptionGroup(BaseException, Generic[_BT_co]): ...
28+
class ExceptionGroup(BaseExceptionGroup[_T_co], Exception): ...

0 commit comments

Comments
 (0)