1
1
from __future__ import annotations
2
2
3
- from collections .abc import Sequence
3
+ from collections .abc import Callable , Sequence
4
4
from functools import partial
5
5
from inspect import getmro , isclass
6
- from typing import Any , Callable , Generic , Tuple , Type , TypeVar , Union , cast
6
+ from typing import TYPE_CHECKING , Any , Generic , Type , TypeVar , cast , overload
7
7
8
- T = TypeVar ("T" , bound = "BaseExceptionGroup" )
9
- EBase = TypeVar ("EBase" , bound = BaseException )
10
- E = TypeVar ("E" , bound = Exception )
11
- _SplitCondition = Union [
12
- Type [EBase ],
13
- Tuple [Type [EBase ], ...],
14
- Callable [[EBase ], bool ],
15
- ]
8
+ if TYPE_CHECKING :
9
+ from typing import Self
10
+
11
+ _BaseExceptionT_co = TypeVar ("_BaseExceptionT_co" , bound = BaseException , covariant = True )
12
+ _BaseExceptionT = TypeVar ("_BaseExceptionT" , bound = BaseException )
13
+ _ExceptionT_co = TypeVar ("_ExceptionT_co" , bound = Exception , covariant = True )
14
+ _ExceptionT = TypeVar ("_ExceptionT" , bound = Exception )
16
15
17
16
18
17
def check_direct_subclass (
@@ -25,7 +24,11 @@ def check_direct_subclass(
25
24
return False
26
25
27
26
28
- def get_condition_filter (condition : _SplitCondition ) -> Callable [[BaseException ], bool ]:
27
+ def get_condition_filter (
28
+ condition : type [_BaseExceptionT ]
29
+ | tuple [type [_BaseExceptionT ], ...]
30
+ | Callable [[_BaseExceptionT_co ], bool ]
31
+ ) -> Callable [[_BaseExceptionT_co ], bool ]:
29
32
if isclass (condition ) and issubclass (
30
33
cast (Type [BaseException ], condition ), BaseException
31
34
):
@@ -34,17 +37,17 @@ def get_condition_filter(condition: _SplitCondition) -> Callable[[BaseException]
34
37
if all (isclass (x ) and issubclass (x , BaseException ) for x in condition ):
35
38
return partial (check_direct_subclass , parents = condition )
36
39
elif callable (condition ):
37
- return cast (Callable [[BaseException ], bool ], condition )
40
+ return cast (" Callable[[BaseException], bool]" , condition )
38
41
39
42
raise TypeError ("expected a function, exception type or tuple of exception types" )
40
43
41
44
42
- class BaseExceptionGroup (BaseException , Generic [EBase ]):
45
+ class BaseExceptionGroup (BaseException , Generic [_BaseExceptionT_co ]):
43
46
"""A combination of multiple unrelated exceptions."""
44
47
45
48
def __new__ (
46
- cls , __message : str , __exceptions : Sequence [EBase ]
47
- ) -> BaseExceptionGroup [ EBase ] | ExceptionGroup [ E ] :
49
+ cls , __message : str , __exceptions : Sequence [_BaseExceptionT_co ]
50
+ ) -> Self :
48
51
if not isinstance (__message , str ):
49
52
raise TypeError (f"argument 1 must be str, not { type (__message )} " )
50
53
if not isinstance (__exceptions , Sequence ):
@@ -66,7 +69,9 @@ def __new__(
66
69
67
70
return super ().__new__ (cls , __message , __exceptions )
68
71
69
- def __init__ (self , __message : str , __exceptions : Sequence [EBase ], * args : Any ):
72
+ def __init__ (
73
+ self , __message : str , __exceptions : Sequence [_BaseExceptionT_co ], * args : Any
74
+ ):
70
75
super ().__init__ (__message , __exceptions , * args )
71
76
self ._message = __message
72
77
self ._exceptions = __exceptions
@@ -87,10 +92,29 @@ def message(self) -> str:
87
92
return self ._message
88
93
89
94
@property
90
- def exceptions (self ) -> tuple [EBase , ...]:
95
+ def exceptions (
96
+ self ,
97
+ ) -> tuple [_BaseExceptionT_co | BaseExceptionGroup [_BaseExceptionT_co ], ...]:
91
98
return tuple (self ._exceptions )
92
99
93
- def subgroup (self : T , __condition : _SplitCondition [EBase ]) -> T | None :
100
+ @overload
101
+ def subgroup (
102
+ self , __condition : type [_BaseExceptionT ] | tuple [type [_BaseExceptionT ], ...]
103
+ ) -> BaseExceptionGroup [_BaseExceptionT ] | None :
104
+ ...
105
+
106
+ @overload
107
+ def subgroup (
108
+ self : Self , __condition : Callable [[_BaseExceptionT_co ], bool ]
109
+ ) -> Self | None :
110
+ ...
111
+
112
+ def subgroup (
113
+ self : Self ,
114
+ __condition : type [_BaseExceptionT ]
115
+ | tuple [type [_BaseExceptionT ], ...]
116
+ | Callable [[_BaseExceptionT_co ], bool ],
117
+ ) -> BaseExceptionGroup [_BaseExceptionT ] | Self | None :
94
118
condition = get_condition_filter (__condition )
95
119
modified = False
96
120
if condition (self ):
@@ -99,7 +123,7 @@ def subgroup(self: T, __condition: _SplitCondition[EBase]) -> T | None:
99
123
exceptions : list [BaseException ] = []
100
124
for exc in self .exceptions :
101
125
if isinstance (exc , BaseExceptionGroup ):
102
- subgroup = exc .subgroup (condition )
126
+ subgroup = exc .subgroup (__condition )
103
127
if subgroup is not None :
104
128
exceptions .append (subgroup )
105
129
@@ -121,9 +145,27 @@ def subgroup(self: T, __condition: _SplitCondition[EBase]) -> T | None:
121
145
else :
122
146
return None
123
147
148
+ @overload
149
+ def split (
150
+ self : Self ,
151
+ __condition : type [_BaseExceptionT ] | tuple [type [_BaseExceptionT ], ...],
152
+ ) -> tuple [BaseExceptionGroup [_BaseExceptionT ] | None , Self | None ]:
153
+ ...
154
+
155
+ @overload
124
156
def split (
125
- self : T , __condition : _SplitCondition [EBase ]
126
- ) -> tuple [T | None , T | None ]:
157
+ self : Self , __condition : Callable [[_BaseExceptionT_co ], bool ]
158
+ ) -> tuple [Self | None , Self | None ]:
159
+ ...
160
+
161
+ def split (
162
+ self : Self ,
163
+ __condition : type [_BaseExceptionT ]
164
+ | tuple [type [_BaseExceptionT ], ...]
165
+ | Callable [[_BaseExceptionT_co ], bool ],
166
+ ) -> tuple [BaseExceptionGroup [_BaseExceptionT ] | None , Self | None ] | tuple [
167
+ Self | None , Self | None
168
+ ]:
127
169
condition = get_condition_filter (__condition )
128
170
if condition (self ):
129
171
return self , None
@@ -143,14 +185,14 @@ def split(
143
185
else :
144
186
nonmatching_exceptions .append (exc )
145
187
146
- matching_group : T | None = None
188
+ matching_group : Self | None = None
147
189
if matching_exceptions :
148
190
matching_group = self .derive (matching_exceptions )
149
191
matching_group .__cause__ = self .__cause__
150
192
matching_group .__context__ = self .__context__
151
193
matching_group .__traceback__ = self .__traceback__
152
194
153
- nonmatching_group : T | None = None
195
+ nonmatching_group : Self | None = None
154
196
if nonmatching_exceptions :
155
197
nonmatching_group = self .derive (nonmatching_exceptions )
156
198
nonmatching_group .__cause__ = self .__cause__
@@ -159,11 +201,12 @@ def split(
159
201
160
202
return matching_group , nonmatching_group
161
203
162
- def derive (self : T , __excs : Sequence [EBase ]) -> T :
204
+ def derive (self : Self , __excs : Sequence [_BaseExceptionT_co ]) -> Self :
163
205
eg = BaseExceptionGroup (self .message , __excs )
164
206
if hasattr (self , "__notes__" ):
165
207
# Create a new list so that add_note() only affects one exceptiongroup
166
208
eg .__notes__ = list (self .__notes__ )
209
+
167
210
return eg
168
211
169
212
def __str__ (self ) -> str :
@@ -174,12 +217,64 @@ def __repr__(self) -> str:
174
217
return f"{ self .__class__ .__name__ } ({ self .message !r} , { self ._exceptions !r} )"
175
218
176
219
177
- class ExceptionGroup (BaseExceptionGroup [E ], Exception , Generic [E ]):
178
- def __new__ (cls , __message : str , __exceptions : Sequence [E ]) -> ExceptionGroup [E ]:
179
- instance : ExceptionGroup [E ] = super ().__new__ (cls , __message , __exceptions )
220
+ class ExceptionGroup (BaseExceptionGroup [_ExceptionT_co ], Exception ):
221
+ def __new__ (cls , __message : str , __exceptions : Sequence [_ExceptionT_co ]) -> Self :
222
+ instance : ExceptionGroup [_ExceptionT_co ] = super ().__new__ (
223
+ cls , __message , __exceptions
224
+ )
180
225
if cls is ExceptionGroup :
181
226
for exc in __exceptions :
182
227
if not isinstance (exc , Exception ):
183
228
raise TypeError ("Cannot nest BaseExceptions in an ExceptionGroup" )
184
229
185
230
return instance
231
+
232
+ if TYPE_CHECKING :
233
+
234
+ @property
235
+ def exceptions (
236
+ self ,
237
+ ) -> tuple [_ExceptionT_co | ExceptionGroup [_ExceptionT_co ], ...]:
238
+ ...
239
+
240
+ @overload # type: ignore[override]
241
+ def subgroup (
242
+ self , __condition : type [_ExceptionT ] | tuple [type [_ExceptionT ], ...]
243
+ ) -> ExceptionGroup [_ExceptionT ] | None :
244
+ ...
245
+
246
+ @overload
247
+ def subgroup (
248
+ self : Self , __condition : Callable [[_ExceptionT_co ], bool ]
249
+ ) -> Self | None :
250
+ ...
251
+
252
+ def subgroup (
253
+ self : Self ,
254
+ __condition : type [_ExceptionT ]
255
+ | tuple [type [_ExceptionT ], ...]
256
+ | Callable [[_ExceptionT_co ], bool ],
257
+ ) -> ExceptionGroup [_ExceptionT ] | Self | None :
258
+ return super ().subgroup (__condition )
259
+
260
+ @overload # type: ignore[override]
261
+ def split (
262
+ self : Self , __condition : type [_ExceptionT ] | tuple [type [_ExceptionT ], ...]
263
+ ) -> tuple [ExceptionGroup [_ExceptionT ] | None , Self | None ]:
264
+ ...
265
+
266
+ @overload
267
+ def split (
268
+ self : Self , __condition : Callable [[_ExceptionT_co ], bool ]
269
+ ) -> tuple [Self | None , Self | None ]:
270
+ ...
271
+
272
+ def split (
273
+ self : Self ,
274
+ __condition : type [_ExceptionT ]
275
+ | tuple [type [_ExceptionT ], ...]
276
+ | Callable [[_ExceptionT_co ], bool ],
277
+ ) -> tuple [ExceptionGroup [_ExceptionT ] | None , Self | None ] | tuple [
278
+ Self | None , Self | None
279
+ ]:
280
+ return super ().split (__condition )
0 commit comments