Skip to content

Commit ee53e9f

Browse files
authored
BaseExceptionGroup.derive should not copy __notes__ (#112)
This makes the behaviour follow that of CPython more closely. Instead, copy `__notes__` (if present) in the *callers* of derive. The (modified) test passes now, and it passes on Python 3.11. It fails before the changes.
1 parent 2f23259 commit ee53e9f

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed

src/exceptiongroup/_exceptions.py

+17-18
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,17 @@ def get_condition_filter(
4242
raise TypeError("expected a function, exception type or tuple of exception types")
4343

4444

45+
def _derive_and_copy_attributes(self, excs):
46+
eg = self.derive(excs)
47+
eg.__cause__ = self.__cause__
48+
eg.__context__ = self.__context__
49+
eg.__traceback__ = self.__traceback__
50+
if hasattr(self, "__notes__"):
51+
# Create a new list so that add_note() only affects one exceptiongroup
52+
eg.__notes__ = list(self.__notes__)
53+
return eg
54+
55+
4556
class BaseExceptionGroup(BaseException, Generic[_BaseExceptionT_co]):
4657
"""A combination of multiple unrelated exceptions."""
4758

@@ -154,10 +165,7 @@ def subgroup(
154165
if not modified:
155166
return self
156167
elif exceptions:
157-
group = self.derive(exceptions)
158-
group.__cause__ = self.__cause__
159-
group.__context__ = self.__context__
160-
group.__traceback__ = self.__traceback__
168+
group = _derive_and_copy_attributes(self, exceptions)
161169
return group
162170
else:
163171
return None
@@ -230,17 +238,13 @@ def split(
230238

231239
matching_group: _BaseExceptionGroupSelf | None = None
232240
if matching_exceptions:
233-
matching_group = self.derive(matching_exceptions)
234-
matching_group.__cause__ = self.__cause__
235-
matching_group.__context__ = self.__context__
236-
matching_group.__traceback__ = self.__traceback__
241+
matching_group = _derive_and_copy_attributes(self, matching_exceptions)
237242

238243
nonmatching_group: _BaseExceptionGroupSelf | None = None
239244
if nonmatching_exceptions:
240-
nonmatching_group = self.derive(nonmatching_exceptions)
241-
nonmatching_group.__cause__ = self.__cause__
242-
nonmatching_group.__context__ = self.__context__
243-
nonmatching_group.__traceback__ = self.__traceback__
245+
nonmatching_group = _derive_and_copy_attributes(
246+
self, nonmatching_exceptions
247+
)
244248

245249
return matching_group, nonmatching_group
246250

@@ -257,12 +261,7 @@ def derive(
257261
def derive(
258262
self, __excs: Sequence[_BaseExceptionT]
259263
) -> BaseExceptionGroup[_BaseExceptionT]:
260-
eg = BaseExceptionGroup(self.message, __excs)
261-
if hasattr(self, "__notes__"):
262-
# Create a new list so that add_note() only affects one exceptiongroup
263-
eg.__notes__ = list(self.__notes__)
264-
265-
return eg
264+
return BaseExceptionGroup(self.message, __excs)
266265

267266
def __str__(self) -> str:
268267
suffix = "" if len(self._exceptions) == 1 else "s"

tests/test_exceptions.py

+14
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,13 @@ def test_notes_is_list_of_strings_if_it_exists(self):
205205
eg.add_note(note)
206206
self.assertEqual(eg.__notes__, [note])
207207

208+
def test_derive_doesn_copy_notes(self):
209+
eg = create_simple_eg()
210+
eg.add_note("hello")
211+
assert eg.__notes__ == ["hello"]
212+
eg2 = eg.derive([ValueError()])
213+
assert not hasattr(eg2, "__notes__")
214+
208215

209216
class ExceptionGroupTestBase(unittest.TestCase):
210217
def assertMatchesTemplate(self, exc, exc_type, template):
@@ -786,6 +793,7 @@ def derive(self, excs):
786793
except ValueError as ve:
787794
raise EG("eg", [ve, nested], 42)
788795
except EG as e:
796+
e.add_note("hello")
789797
eg = e
790798

791799
self.assertMatchesTemplate(eg, EG, [ValueError(1), [TypeError(2)]])
@@ -796,29 +804,35 @@ def derive(self, excs):
796804
self.assertMatchesTemplate(rest, EG, [ValueError(1), [TypeError(2)]])
797805
self.assertEqual(rest.code, 42)
798806
self.assertEqual(rest.exceptions[1].code, 101)
807+
self.assertEqual(rest.__notes__, ["hello"])
799808

800809
# Match Everything
801810
match, rest = self.split_exception_group(eg, (ValueError, TypeError))
802811
self.assertMatchesTemplate(match, EG, [ValueError(1), [TypeError(2)]])
803812
self.assertEqual(match.code, 42)
804813
self.assertEqual(match.exceptions[1].code, 101)
814+
self.assertEqual(match.__notes__, ["hello"])
805815
self.assertIsNone(rest)
806816

807817
# Match ValueErrors
808818
match, rest = self.split_exception_group(eg, ValueError)
809819
self.assertMatchesTemplate(match, EG, [ValueError(1)])
810820
self.assertEqual(match.code, 42)
821+
self.assertEqual(match.__notes__, ["hello"])
811822
self.assertMatchesTemplate(rest, EG, [[TypeError(2)]])
812823
self.assertEqual(rest.code, 42)
813824
self.assertEqual(rest.exceptions[0].code, 101)
825+
self.assertEqual(rest.__notes__, ["hello"])
814826

815827
# Match TypeErrors
816828
match, rest = self.split_exception_group(eg, TypeError)
817829
self.assertMatchesTemplate(match, EG, [[TypeError(2)]])
818830
self.assertEqual(match.code, 42)
819831
self.assertEqual(match.exceptions[0].code, 101)
832+
self.assertEqual(match.__notes__, ["hello"])
820833
self.assertMatchesTemplate(rest, EG, [ValueError(1)])
821834
self.assertEqual(rest.code, 42)
835+
self.assertEqual(rest.__notes__, ["hello"])
822836

823837

824838
def test_repr():

0 commit comments

Comments
 (0)