Skip to content

Commit cbac14e

Browse files
authored
Updated type annotations to match those in typeshed (#40)
1 parent 2167639 commit cbac14e

File tree

2 files changed

+126
-27
lines changed

2 files changed

+126
-27
lines changed

CHANGES.rst

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ Version history
33

44
This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
55

6+
**UNRELEASED**
7+
8+
- Updated type annotations to match the ones in ``typeshed``
9+
610
**1.0.1**
711

812
- Fixed formatted traceback missing exceptions beyond 2 nesting levels of

src/exceptiongroup/_exceptions.py

+122-27
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
from __future__ import annotations
22

3-
from collections.abc import Sequence
3+
from collections.abc import Callable, Sequence
44
from functools import partial
55
from inspect import getmro, isclass
6-
from typing import Any, Callable, Generic, Tuple, Type, TypeVar, Union, cast
6+
from typing import TYPE_CHECKING, Any, Generic, Type, TypeVar, cast, overload
77

8-
T = TypeVar("T", bound="BaseExceptionGroup")
9-
EBase = TypeVar("EBase", bound=BaseException)
10-
E = TypeVar("E", bound=Exception)
11-
_SplitCondition = Union[
12-
Type[EBase],
13-
Tuple[Type[EBase], ...],
14-
Callable[[EBase], bool],
15-
]
8+
if TYPE_CHECKING:
9+
from typing import Self
10+
11+
_BaseExceptionT_co = TypeVar("_BaseExceptionT_co", bound=BaseException, covariant=True)
12+
_BaseExceptionT = TypeVar("_BaseExceptionT", bound=BaseException)
13+
_ExceptionT_co = TypeVar("_ExceptionT_co", bound=Exception, covariant=True)
14+
_ExceptionT = TypeVar("_ExceptionT", bound=Exception)
1615

1716

1817
def check_direct_subclass(
@@ -25,7 +24,11 @@ def check_direct_subclass(
2524
return False
2625

2726

28-
def get_condition_filter(condition: _SplitCondition) -> Callable[[BaseException], bool]:
27+
def get_condition_filter(
28+
condition: type[_BaseExceptionT]
29+
| tuple[type[_BaseExceptionT], ...]
30+
| Callable[[_BaseExceptionT_co], bool]
31+
) -> Callable[[_BaseExceptionT_co], bool]:
2932
if isclass(condition) and issubclass(
3033
cast(Type[BaseException], condition), BaseException
3134
):
@@ -34,17 +37,17 @@ def get_condition_filter(condition: _SplitCondition) -> Callable[[BaseException]
3437
if all(isclass(x) and issubclass(x, BaseException) for x in condition):
3538
return partial(check_direct_subclass, parents=condition)
3639
elif callable(condition):
37-
return cast(Callable[[BaseException], bool], condition)
40+
return cast("Callable[[BaseException], bool]", condition)
3841

3942
raise TypeError("expected a function, exception type or tuple of exception types")
4043

4144

42-
class BaseExceptionGroup(BaseException, Generic[EBase]):
45+
class BaseExceptionGroup(BaseException, Generic[_BaseExceptionT_co]):
4346
"""A combination of multiple unrelated exceptions."""
4447

4548
def __new__(
46-
cls, __message: str, __exceptions: Sequence[EBase]
47-
) -> BaseExceptionGroup[EBase] | ExceptionGroup[E]:
49+
cls, __message: str, __exceptions: Sequence[_BaseExceptionT_co]
50+
) -> Self:
4851
if not isinstance(__message, str):
4952
raise TypeError(f"argument 1 must be str, not {type(__message)}")
5053
if not isinstance(__exceptions, Sequence):
@@ -66,7 +69,9 @@ def __new__(
6669

6770
return super().__new__(cls, __message, __exceptions)
6871

69-
def __init__(self, __message: str, __exceptions: Sequence[EBase], *args: Any):
72+
def __init__(
73+
self, __message: str, __exceptions: Sequence[_BaseExceptionT_co], *args: Any
74+
):
7075
super().__init__(__message, __exceptions, *args)
7176
self._message = __message
7277
self._exceptions = __exceptions
@@ -87,10 +92,29 @@ def message(self) -> str:
8792
return self._message
8893

8994
@property
90-
def exceptions(self) -> tuple[EBase, ...]:
95+
def exceptions(
96+
self,
97+
) -> tuple[_BaseExceptionT_co | BaseExceptionGroup[_BaseExceptionT_co], ...]:
9198
return tuple(self._exceptions)
9299

93-
def subgroup(self: T, __condition: _SplitCondition[EBase]) -> T | None:
100+
@overload
101+
def subgroup(
102+
self, __condition: type[_BaseExceptionT] | tuple[type[_BaseExceptionT], ...]
103+
) -> BaseExceptionGroup[_BaseExceptionT] | None:
104+
...
105+
106+
@overload
107+
def subgroup(
108+
self: Self, __condition: Callable[[_BaseExceptionT_co], bool]
109+
) -> Self | None:
110+
...
111+
112+
def subgroup(
113+
self: Self,
114+
__condition: type[_BaseExceptionT]
115+
| tuple[type[_BaseExceptionT], ...]
116+
| Callable[[_BaseExceptionT_co], bool],
117+
) -> BaseExceptionGroup[_BaseExceptionT] | Self | None:
94118
condition = get_condition_filter(__condition)
95119
modified = False
96120
if condition(self):
@@ -99,7 +123,7 @@ def subgroup(self: T, __condition: _SplitCondition[EBase]) -> T | None:
99123
exceptions: list[BaseException] = []
100124
for exc in self.exceptions:
101125
if isinstance(exc, BaseExceptionGroup):
102-
subgroup = exc.subgroup(condition)
126+
subgroup = exc.subgroup(__condition)
103127
if subgroup is not None:
104128
exceptions.append(subgroup)
105129

@@ -121,9 +145,27 @@ def subgroup(self: T, __condition: _SplitCondition[EBase]) -> T | None:
121145
else:
122146
return None
123147

148+
@overload
149+
def split(
150+
self: Self,
151+
__condition: type[_BaseExceptionT] | tuple[type[_BaseExceptionT], ...],
152+
) -> tuple[BaseExceptionGroup[_BaseExceptionT] | None, Self | None]:
153+
...
154+
155+
@overload
124156
def split(
125-
self: T, __condition: _SplitCondition[EBase]
126-
) -> tuple[T | None, T | None]:
157+
self: Self, __condition: Callable[[_BaseExceptionT_co], bool]
158+
) -> tuple[Self | None, Self | None]:
159+
...
160+
161+
def split(
162+
self: Self,
163+
__condition: type[_BaseExceptionT]
164+
| tuple[type[_BaseExceptionT], ...]
165+
| Callable[[_BaseExceptionT_co], bool],
166+
) -> tuple[BaseExceptionGroup[_BaseExceptionT] | None, Self | None] | tuple[
167+
Self | None, Self | None
168+
]:
127169
condition = get_condition_filter(__condition)
128170
if condition(self):
129171
return self, None
@@ -143,14 +185,14 @@ def split(
143185
else:
144186
nonmatching_exceptions.append(exc)
145187

146-
matching_group: T | None = None
188+
matching_group: Self | None = None
147189
if matching_exceptions:
148190
matching_group = self.derive(matching_exceptions)
149191
matching_group.__cause__ = self.__cause__
150192
matching_group.__context__ = self.__context__
151193
matching_group.__traceback__ = self.__traceback__
152194

153-
nonmatching_group: T | None = None
195+
nonmatching_group: Self | None = None
154196
if nonmatching_exceptions:
155197
nonmatching_group = self.derive(nonmatching_exceptions)
156198
nonmatching_group.__cause__ = self.__cause__
@@ -159,11 +201,12 @@ def split(
159201

160202
return matching_group, nonmatching_group
161203

162-
def derive(self: T, __excs: Sequence[EBase]) -> T:
204+
def derive(self: Self, __excs: Sequence[_BaseExceptionT_co]) -> Self:
163205
eg = BaseExceptionGroup(self.message, __excs)
164206
if hasattr(self, "__notes__"):
165207
# Create a new list so that add_note() only affects one exceptiongroup
166208
eg.__notes__ = list(self.__notes__)
209+
167210
return eg
168211

169212
def __str__(self) -> str:
@@ -174,12 +217,64 @@ def __repr__(self) -> str:
174217
return f"{self.__class__.__name__}({self.message!r}, {self._exceptions!r})"
175218

176219

177-
class ExceptionGroup(BaseExceptionGroup[E], Exception, Generic[E]):
178-
def __new__(cls, __message: str, __exceptions: Sequence[E]) -> ExceptionGroup[E]:
179-
instance: ExceptionGroup[E] = super().__new__(cls, __message, __exceptions)
220+
class ExceptionGroup(BaseExceptionGroup[_ExceptionT_co], Exception):
221+
def __new__(cls, __message: str, __exceptions: Sequence[_ExceptionT_co]) -> Self:
222+
instance: ExceptionGroup[_ExceptionT_co] = super().__new__(
223+
cls, __message, __exceptions
224+
)
180225
if cls is ExceptionGroup:
181226
for exc in __exceptions:
182227
if not isinstance(exc, Exception):
183228
raise TypeError("Cannot nest BaseExceptions in an ExceptionGroup")
184229

185230
return instance
231+
232+
if TYPE_CHECKING:
233+
234+
@property
235+
def exceptions(
236+
self,
237+
) -> tuple[_ExceptionT_co | ExceptionGroup[_ExceptionT_co], ...]:
238+
...
239+
240+
@overload # type: ignore[override]
241+
def subgroup(
242+
self, __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...]
243+
) -> ExceptionGroup[_ExceptionT] | None:
244+
...
245+
246+
@overload
247+
def subgroup(
248+
self: Self, __condition: Callable[[_ExceptionT_co], bool]
249+
) -> Self | None:
250+
...
251+
252+
def subgroup(
253+
self: Self,
254+
__condition: type[_ExceptionT]
255+
| tuple[type[_ExceptionT], ...]
256+
| Callable[[_ExceptionT_co], bool],
257+
) -> ExceptionGroup[_ExceptionT] | Self | None:
258+
return super().subgroup(__condition)
259+
260+
@overload # type: ignore[override]
261+
def split(
262+
self: Self, __condition: type[_ExceptionT] | tuple[type[_ExceptionT], ...]
263+
) -> tuple[ExceptionGroup[_ExceptionT] | None, Self | None]:
264+
...
265+
266+
@overload
267+
def split(
268+
self: Self, __condition: Callable[[_ExceptionT_co], bool]
269+
) -> tuple[Self | None, Self | None]:
270+
...
271+
272+
def split(
273+
self: Self,
274+
__condition: type[_ExceptionT]
275+
| tuple[type[_ExceptionT], ...]
276+
| Callable[[_ExceptionT_co], bool],
277+
) -> tuple[ExceptionGroup[_ExceptionT] | None, Self | None] | tuple[
278+
Self | None, Self | None
279+
]:
280+
return super().split(__condition)

0 commit comments

Comments
 (0)