diff --git a/aws_xray_sdk/core/models/entity.py b/aws_xray_sdk/core/models/entity.py index 8b8bdf4f..07641062 100644 --- a/aws_xray_sdk/core/models/entity.py +++ b/aws_xray_sdk/core/models/entity.py @@ -211,7 +211,7 @@ def add_exception(self, exception, stack, remote=False): """ Add an exception to trace entities. - :param Exception exception: the catched exception. + :param Exception exception: the caught exception. :param list stack: the output from python built-in `traceback.extract_stack()`. :param bool remote: If False it means it's a client error @@ -224,7 +224,16 @@ def add_exception(self, exception, stack, remote=False): setattr(self, 'cause', getattr(exception, '_cause_id')) return - exceptions = [] + if not isinstance(self.cause, dict): + log.warning("The current cause object is not a dict but an id: {}. Resetting the cause and recording the " + "current exception".format(self.cause)) + self.cause = {} + + if 'exceptions' in self.cause: + exceptions = self.cause['exceptions'] + else: + exceptions = [] + exceptions.append(Throwable(exception, stack, remote)) self.cause['exceptions'] = exceptions diff --git a/tests/test_trace_entities.py b/tests/test_trace_entities.py index 01754d8e..e42cee0c 100644 --- a/tests/test_trace_entities.py +++ b/tests/test_trace_entities.py @@ -1,5 +1,7 @@ # -*- coding: iso-8859-15 -*- + import pytest +import sys from aws_xray_sdk.core.models.segment import Segment from aws_xray_sdk.core.models.subsegment import Subsegment @@ -194,3 +196,70 @@ def test_missing_parent_segment(): with pytest.raises(SegmentNotFoundException): Subsegment('name', 'local', None) + + +def test_add_exception(): + segment = Segment('seg') + exception = Exception("testException") + stack = [['path', 'line', 'label']] + segment.add_exception(exception=exception, stack=stack) + segment.close() + + cause = segment.cause + assert 'exceptions' in cause + exceptions = cause['exceptions'] + assert len(exceptions) == 1 + assert 'working_directory' in cause + exception = exceptions[0] + assert 'testException' == exception.message + expected_stack = [{'path': 'path', 'line': 'line', 'label': 'label'}] + assert expected_stack == exception.stack + + +def test_add_exception_referencing(): + segment = Segment('seg') + subseg = Subsegment('subseg', 'remote', segment) + exception = Exception("testException") + stack = [['path', 'line', 'label']] + subseg.add_exception(exception=exception, stack=stack) + segment.add_exception(exception=exception, stack=stack) + subseg.close() + segment.close() + + seg_cause = segment.cause + subseg_cause = subseg.cause + + assert isinstance(subseg_cause, dict) + if sys.version_info.major == 2: + assert isinstance(seg_cause, basestring) + else: + assert isinstance(seg_cause, str) + assert seg_cause == subseg_cause['exceptions'][0].id + + +def test_add_exception_cause_resetting(): + segment = Segment('seg') + subseg = Subsegment('subseg', 'remote', segment) + exception = Exception("testException") + stack = [['path', 'line', 'label']] + subseg.add_exception(exception=exception, stack=stack) + segment.add_exception(exception=exception, stack=stack) + + segment.add_exception(exception=Exception("newException"), stack=stack) + subseg.close() + segment.close() + + seg_cause = segment.cause + assert isinstance(seg_cause, dict) + assert 'newException' == seg_cause['exceptions'][0].message + + +def test_add_exception_appending_exceptions(): + segment = Segment('seg') + stack = [['path', 'line', 'label']] + segment.add_exception(exception=Exception("testException"), stack=stack) + segment.add_exception(exception=Exception("newException"), stack=stack) + segment.close() + + assert isinstance(segment.cause, dict) + assert len(segment.cause['exceptions']) == 2