Skip to content

Commit ef61b87

Browse files
authored
Merge pull request #127 from chanchiem/serverless
Add Serverless Framework Support
2 parents 237f8c7 + 293f76f commit ef61b87

File tree

6 files changed

+133
-35
lines changed

6 files changed

+133
-35
lines changed

aws_xray_sdk/core/models/entity.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
_common_invalid_name_characters = '?;*()!$~^<>'
1919
_valid_annotation_key_characters = string.ascii_letters + string.digits + '_'
2020

21+
ORIGIN_TRACE_HEADER_ATTR_KEY = '_origin_trace_header'
22+
2123

2224
class Entity(object):
2325
"""
@@ -228,6 +230,20 @@ def add_exception(self, exception, stack, remote=False):
228230
self.cause['exceptions'] = exceptions
229231
self.cause['working_directory'] = os.getcwd()
230232

233+
def save_origin_trace_header(self, trace_header):
234+
"""
235+
Temporarily store additional data fields in trace header
236+
to the entity for later propagation. The data will be
237+
cleaned up upon serialization.
238+
"""
239+
setattr(self, ORIGIN_TRACE_HEADER_ATTR_KEY, trace_header)
240+
241+
def get_origin_trace_header(self):
242+
"""
243+
Retrieve saved trace header data.
244+
"""
245+
return getattr(self, ORIGIN_TRACE_HEADER_ATTR_KEY, None)
246+
231247
def serialize(self):
232248
"""
233249
Serialize to JSON document that can be accepted by the
@@ -258,6 +274,7 @@ def _delete_empty_properties(self, properties):
258274
del properties['annotations']
259275
if not self.metadata:
260276
del properties['metadata']
277+
properties.pop(ORIGIN_TRACE_HEADER_ATTR_KEY, None)
261278

262279
del properties['sampled']
263280

aws_xray_sdk/core/models/segment.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -155,20 +155,6 @@ def set_rule_name(self, rule_name):
155155
self.aws['xray'] = {}
156156
self.aws['xray']['sampling_rule_name'] = rule_name
157157

158-
def save_origin_trace_header(self, trace_header):
159-
"""
160-
Temporarily store additional data fields in trace header
161-
to the segment for later propagation. The data will be
162-
cleaned up upon serilaization.
163-
"""
164-
setattr(self, ORIGIN_TRACE_HEADER_ATTR_KEY, trace_header)
165-
166-
def get_origin_trace_header(self):
167-
"""
168-
Retrieve saved trace header data.
169-
"""
170-
return getattr(self, ORIGIN_TRACE_HEADER_ATTR_KEY, None)
171-
172158
def __getstate__(self):
173159
"""
174160
Used by jsonpikle to remove unwanted fields.
@@ -179,5 +165,4 @@ def __getstate__(self):
179165
del properties['user']
180166
del properties['ref_counter']
181167
del properties['_subsegments_counter']
182-
properties.pop(ORIGIN_TRACE_HEADER_ATTR_KEY, None)
183168
return properties

aws_xray_sdk/ext/django/middleware.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from aws_xray_sdk.core.utils import stacktrace
66
from aws_xray_sdk.ext.util import calculate_sampling_decision, \
77
calculate_segment_name, construct_xray_header, prepare_response_header
8+
from aws_xray_sdk.core.lambda_launcher import check_in_lambda
89

910

1011
log = logging.getLogger(__name__)
@@ -24,6 +25,10 @@ class XRayMiddleware(object):
2425
def __init__(self, get_response):
2526

2627
self.get_response = get_response
28+
self.in_lambda = False
29+
30+
if check_in_lambda():
31+
self.in_lambda = True
2732

2833
# hooks for django version >= 1.10
2934
def __call__(self, request):
@@ -46,12 +51,15 @@ def __call__(self, request):
4651
sampling_req=sampling_req,
4752
)
4853

49-
segment = xray_recorder.begin_segment(
50-
name=name,
51-
traceid=xray_header.root,
52-
parent_id=xray_header.parent,
53-
sampling=sampling_decision,
54-
)
54+
if self.in_lambda:
55+
segment = xray_recorder.begin_subsegment(name)
56+
else:
57+
segment = xray_recorder.begin_segment(
58+
name=name,
59+
traceid=xray_header.root,
60+
parent_id=xray_header.parent,
61+
sampling=sampling_decision,
62+
)
5563

5664
segment.save_origin_trace_header(xray_header)
5765
segment.put_http_meta(http.URL, request.build_absolute_uri())
@@ -75,7 +83,10 @@ def __call__(self, request):
7583
segment.put_http_meta(http.CONTENT_LENGTH, length)
7684
response[http.XRAY_HEADER] = prepare_response_header(xray_header, segment)
7785

78-
xray_recorder.end_segment()
86+
if self.in_lambda:
87+
xray_recorder.end_subsegment()
88+
else:
89+
xray_recorder.end_segment()
7990

8091
return response
8192

aws_xray_sdk/ext/flask/middleware.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from aws_xray_sdk.core.utils import stacktrace
66
from aws_xray_sdk.ext.util import calculate_sampling_decision, \
77
calculate_segment_name, construct_xray_header, prepare_response_header
8+
from aws_xray_sdk.core.lambda_launcher import check_in_lambda
89

910

1011
class XRayMiddleware(object):
@@ -17,6 +18,10 @@ def __init__(self, app, recorder):
1718
self.app.before_request(self._before_request)
1819
self.app.after_request(self._after_request)
1920
self.app.teardown_request(self._handle_exception)
21+
self.in_lambda = False
22+
23+
if check_in_lambda():
24+
self.in_lambda = True
2025

2126
_patch_render(recorder)
2227

@@ -39,12 +44,15 @@ def _before_request(self):
3944
sampling_req=sampling_req,
4045
)
4146

42-
segment = self._recorder.begin_segment(
43-
name=name,
44-
traceid=xray_header.root,
45-
parent_id=xray_header.parent,
46-
sampling=sampling_decision,
47-
)
47+
if self.in_lambda:
48+
segment = self._recorder.begin_subsegment(name)
49+
else:
50+
segment = self._recorder.begin_segment(
51+
name=name,
52+
traceid=xray_header.root,
53+
parent_id=xray_header.parent,
54+
sampling=sampling_decision,
55+
)
4856

4957
segment.save_origin_trace_header(xray_header)
5058
segment.put_http_meta(http.URL, req.base_url)
@@ -59,7 +67,10 @@ def _before_request(self):
5967
segment.put_http_meta(http.CLIENT_IP, req.remote_addr)
6068

6169
def _after_request(self, response):
62-
segment = self._recorder.current_segment()
70+
if self.in_lambda:
71+
segment = self._recorder.current_subsegment()
72+
else:
73+
segment = self._recorder.current_segment()
6374
segment.put_http_meta(http.STATUS, response.status_code)
6475

6576
origin_header = segment.get_origin_trace_header()
@@ -70,15 +81,21 @@ def _after_request(self, response):
7081
if cont_len:
7182
segment.put_http_meta(http.CONTENT_LENGTH, int(cont_len))
7283

73-
self._recorder.end_segment()
84+
if self.in_lambda:
85+
self._recorder.end_subsegment()
86+
else:
87+
self._recorder.end_segment()
7488
return response
7589

7690
def _handle_exception(self, exception):
7791
if not exception:
7892
return
7993
segment = None
8094
try:
81-
segment = self._recorder.current_segment()
95+
if self.in_lambda:
96+
segment = self._recorder.current_subsegment()
97+
else:
98+
segment = self._recorder.current_segment()
8299
except Exception:
83100
pass
84101
if not segment:
@@ -87,7 +104,10 @@ def _handle_exception(self, exception):
87104
segment.put_http_meta(http.STATUS, 500)
88105
stack = stacktrace.get_stacktrace(limit=self._recorder._max_trace_back)
89106
segment.add_exception(exception, stack)
90-
self._recorder.end_segment()
107+
if self.in_lambda:
108+
self._recorder.end_subsegment()
109+
else:
110+
self._recorder.end_segment()
91111

92112

93113
def _patch_render(recorder):

tests/ext/django/test_middleware.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
from django.core.urlresolvers import reverse
44
from django.test import TestCase
55

6-
from aws_xray_sdk.core import xray_recorder
6+
from aws_xray_sdk.core import xray_recorder, lambda_launcher
77
from aws_xray_sdk.core.context import Context
8-
from aws_xray_sdk.core.models import http
8+
from aws_xray_sdk.core.models import http, facade_segment
9+
from tests.util import get_new_stubbed_recorder
10+
import os
911

1012

1113
class XRayTestCase(TestCase):
@@ -111,3 +113,22 @@ def test_disabled_sdk(self):
111113
self.client.get(url)
112114
segment = xray_recorder.emitter.pop()
113115
assert not segment
116+
117+
def test_lambda_serverless(self):
118+
TRACE_ID = '1-5759e988-bd862e3fe1be46a994272793'
119+
PARENT_ID = '53995c3f42cd8ad8'
120+
HEADER_VAR = "Root=%s;Parent=%s;Sampled=1" % (TRACE_ID, PARENT_ID)
121+
122+
os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY] = HEADER_VAR
123+
lambda_context = lambda_launcher.LambdaContext()
124+
125+
new_recorder = get_new_stubbed_recorder()
126+
new_recorder.configure(service='test', sampling=False, context=lambda_context)
127+
subsegment = new_recorder.begin_subsegment("subsegment")
128+
assert type(subsegment.parent_segment) == facade_segment.FacadeSegment
129+
new_recorder.end_subsegment()
130+
131+
url = reverse('200ok')
132+
self.client.get(url)
133+
segment = new_recorder.emitter.pop()
134+
assert not segment

tests/ext/flask/test_flask.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from aws_xray_sdk import global_sdk_config
55
from aws_xray_sdk.ext.flask.middleware import XRayMiddleware
66
from aws_xray_sdk.core.context import Context
7-
from aws_xray_sdk.core.models import http
7+
from aws_xray_sdk.core import lambda_launcher
8+
from aws_xray_sdk.core.models import http, facade_segment
89
from tests.util import get_new_stubbed_recorder
10+
import os
911

1012

1113
# define a flask app for testing purpose
@@ -153,3 +155,45 @@ def test_disabled_sdk():
153155
app.get(path)
154156
segment = recorder.emitter.pop()
155157
assert not segment
158+
159+
160+
def test_lambda_serverless():
161+
TRACE_ID = '1-5759e988-bd862e3fe1be46a994272793'
162+
PARENT_ID = '53995c3f42cd8ad8'
163+
HEADER_VAR = "Root=%s;Parent=%s;Sampled=1" % (TRACE_ID, PARENT_ID)
164+
165+
os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY] = HEADER_VAR
166+
lambda_context = lambda_launcher.LambdaContext()
167+
168+
new_recorder = get_new_stubbed_recorder()
169+
new_recorder.configure(service='test', sampling=False, context=lambda_context)
170+
new_app = Flask(__name__)
171+
172+
@new_app.route('/subsegment')
173+
def subsegment():
174+
# Test in between request and make sure Serverless creates a subsegment instead of a segment.
175+
# Ensure that the parent segment is a facade segment.
176+
assert new_recorder.current_subsegment()
177+
assert type(new_recorder.current_segment()) == facade_segment.FacadeSegment
178+
return 'ok'
179+
180+
@new_app.route('/trace_header')
181+
def trace_header():
182+
# Ensure trace header is preserved.
183+
subsegment = new_recorder.current_subsegment()
184+
header = subsegment.get_origin_trace_header()
185+
assert header.data['k1'] == 'v1'
186+
return 'ok'
187+
188+
middleware = XRayMiddleware(new_app, new_recorder)
189+
middleware.in_lambda = True
190+
191+
app_client = new_app.test_client()
192+
193+
path = '/subsegment'
194+
app_client.get(path)
195+
segment = recorder.emitter.pop()
196+
assert not segment # Segment should be none because it's created and ended by the middleware
197+
198+
path2 = '/trace_header'
199+
app_client.get(path2, headers={http.XRAY_HEADER: 'k1=v1'})

0 commit comments

Comments
 (0)