Skip to content

Commit 63b5ea2

Browse files
authored
Merge pull request #3 from terrycain/aiohttp_middleware
AioHttp server middleware
2 parents a00b8b5 + 9672b13 commit 63b5ea2

File tree

12 files changed

+468
-14
lines changed

12 files changed

+468
-14
lines changed

aws_xray_sdk/core/async_context.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import asyncio
2+
3+
from .context import Context as _Context
4+
5+
6+
class AsyncContext(_Context):
7+
"""
8+
Async Context for storing segments.
9+
10+
Inherits nearly everything from the main Context class.
11+
Replaces threading.local with a task based local storage class,
12+
Also overrides clear_trace_entities
13+
"""
14+
def __init__(self, *args, loop=None, use_task_factory=True, **kwargs):
15+
super(AsyncContext, self).__init__(*args, **kwargs)
16+
17+
self._loop = loop
18+
if loop is None:
19+
self._loop = asyncio.get_event_loop()
20+
21+
if use_task_factory:
22+
self._loop.set_task_factory(task_factory)
23+
24+
self._local = TaskLocalStorage(loop=loop)
25+
26+
def clear_trace_entities(self):
27+
"""
28+
Clear all trace_entities stored in the task local context.
29+
"""
30+
if self._local is not None:
31+
self._local.clear()
32+
33+
34+
class TaskLocalStorage(object):
35+
"""
36+
Simple task local storage
37+
"""
38+
def __init__(self, loop=None):
39+
if loop is None:
40+
loop = asyncio.get_event_loop()
41+
self._loop = loop
42+
43+
def __setattr__(self, name, value):
44+
if name in ('_loop',):
45+
# Set normal attributes
46+
object.__setattr__(self, name, value)
47+
48+
else:
49+
# Set task local attributes
50+
task = asyncio.Task.current_task(loop=self._loop)
51+
if task is None:
52+
return None
53+
54+
if not hasattr(task, 'context'):
55+
task.context = {}
56+
57+
task.context[name] = value
58+
59+
def __getattribute__(self, item):
60+
if item in ('_loop', 'clear'):
61+
# Return references to local objects
62+
return object.__getattribute__(self, item)
63+
64+
task = asyncio.Task.current_task(loop=self._loop)
65+
if task is None:
66+
return None
67+
68+
if hasattr(task, 'context') and item in task.context:
69+
return task.context[item]
70+
71+
raise AttributeError('Task context does not have attribute {0}'.format(item))
72+
73+
def clear(self):
74+
# If were in a task, clear the context dictionary
75+
task = asyncio.Task.current_task(loop=self._loop)
76+
if task is not None and hasattr(task, 'context'):
77+
task.context.clear()
78+
79+
80+
def task_factory(loop, coro):
81+
"""
82+
Task factory function
83+
84+
Fuction closely mirrors the logic inside of
85+
asyncio.BaseEventLoop.create_task. Then if there is a current
86+
task and the current task has a context then share that context
87+
with the new task
88+
"""
89+
task = asyncio.Task(coro, loop=loop)
90+
if task._source_traceback: # flake8: noqa
91+
del task._source_traceback[-1] # flake8: noqa
92+
93+
# Share context with new task if possible
94+
current_task = asyncio.Task.current_task(loop=loop)
95+
if current_task is not None and hasattr(current_task, 'context'):
96+
setattr(task, 'context', current_task.context)
97+
98+
return task

aws_xray_sdk/core/sampling/sampling_rule.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def applies(self, service_name, method, path):
4040
the incoming request based on some of the request's parameters.
4141
Any None parameters provided will be considered an implicit match.
4242
"""
43-
return (not service_name or wildcard_match(self.service_name, service_name)) \
43+
return (not service_name or wildcard_match(self.service_name, service_name)) \
4444
and (not method or wildcard_match(self.service_name, method)) \
4545
and (not path or wildcard_match(self.path, path))
4646

@@ -89,11 +89,14 @@ def reservoir(self):
8989

9090
def _validate(self):
9191
if self.fixed_target < 0 or self.rate < 0:
92-
raise InvalidSamplingManifestError('All rules must have non-negative values for fixed_target and rate')
92+
raise InvalidSamplingManifestError('All rules must have non-negative values for '
93+
'fixed_target and rate')
9394

9495
if self._default:
9596
if self.service_name or self.method or self.path:
96-
raise InvalidSamplingManifestError('The default rule must not specify values for url_path, service_name, or http_method')
97+
raise InvalidSamplingManifestError('The default rule must not specify values for '
98+
'url_path, service_name, or http_method')
9799
else:
98100
if not self.service_name or not self.method or not self.path:
99-
raise InvalidSamplingManifestError('All non-default rules must have values for url_path, service_name, and http_method')
101+
raise InvalidSamplingManifestError('All non-default rules must have values for '
102+
'url_path, service_name, and http_method')

aws_xray_sdk/core/utils/compat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
PY2 = sys.version_info < (3,)
55

66
if PY2:
7-
annotation_value_types = (int, long, float, bool, str)
8-
string_types = basestring
7+
annotation_value_types = (int, long, float, bool, str) # noqa: F821
8+
string_types = basestring # noqa: F821
99
else:
1010
annotation_value_types = (int, float, bool, str)
1111
string_types = str

aws_xray_sdk/ext/aiohttp/__init__.py

Whitespace-only changes.
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""
2+
AioHttp Middleware
3+
"""
4+
import traceback
5+
6+
from aws_xray_sdk.core import xray_recorder
7+
from aws_xray_sdk.core.models import http
8+
from aws_xray_sdk.ext.util import calculate_sampling_decision, calculate_segment_name, construct_xray_header
9+
10+
11+
async def middleware(app, handler):
12+
"""
13+
AioHttp Middleware Factory
14+
"""
15+
async def _middleware(request):
16+
"""
17+
Main middleware function, deals with all the X-Ray segment logic
18+
"""
19+
# Create X-Ray headers
20+
xray_header = construct_xray_header(request.headers)
21+
# Get name of service or generate a dynamic one from host
22+
name = calculate_segment_name(request.headers['host'].split(':', 1)[0], xray_recorder)
23+
24+
sampling_decision = calculate_sampling_decision(
25+
trace_header=xray_header,
26+
recorder=xray_recorder,
27+
service_name=request.headers['host'],
28+
method=request.method,
29+
path=request.path,
30+
)
31+
32+
# Start a segment
33+
segment = xray_recorder.begin_segment(
34+
name=name,
35+
traceid=xray_header.root,
36+
parent_id=xray_header.parent,
37+
sampling=sampling_decision,
38+
)
39+
40+
# Store request metadata in the current segment
41+
segment.put_http_meta(http.URL, request.url)
42+
segment.put_http_meta(http.METHOD, request.method)
43+
44+
if 'User-Agent' in request.headers:
45+
segment.put_http_meta(http.USER_AGENT, request.headers['User-Agent'])
46+
47+
if 'X-Forwarded-For' in request.headers:
48+
segment.put_http_meta(http.CLIENT_IP, request.headers['X-Forwarded-For'])
49+
segment.put_http_meta(http.X_FORWARDED_FOR, True)
50+
elif 'remote_addr' in request.headers:
51+
segment.put_http_meta(http.CLIENT_IP, request.headers['remote_addr'])
52+
else:
53+
segment.put_http_meta(http.CLIENT_IP, request.remote)
54+
55+
try:
56+
# Call next middleware or request handler
57+
response = await handler(request)
58+
except Exception as err:
59+
# Store exception information including the stacktrace to the segment
60+
segment = xray_recorder.current_segment()
61+
segment.put_http_meta(http.STATUS, 500)
62+
stack = traceback.extract_stack(limit=xray_recorder._max_trace_back)
63+
segment.add_exception(err, stack)
64+
xray_recorder.end_segment()
65+
raise
66+
67+
# Store response metadata into the current segment
68+
segment.put_http_meta(http.STATUS, response.status)
69+
70+
if 'Content-Length' in response.headers:
71+
length = int(response.headers['Content-Length'])
72+
segment.put_http_meta(http.CONTENT_LENGTH, length)
73+
74+
# Close segment so it can be dispatched off to the daemon
75+
xray_recorder.end_segment()
76+
77+
return response
78+
return _middleware

docs/conf.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
3232
# ones.
3333
extensions = ['sphinx.ext.autodoc',
34-
'sphinx.ext.doctest',
35-
'sphinx.ext.intersphinx',
36-
'sphinx.ext.coverage']
34+
'sphinx.ext.doctest',
35+
'sphinx.ext.intersphinx',
36+
'sphinx.ext.coverage']
3737

3838
# Add any paths that contain templates here, relative to this directory.
3939
templates_path = ['_templates']
@@ -171,7 +171,5 @@
171171
]
172172

173173

174-
175-
176174
# Example configuration for intersphinx: refer to the Python standard library.
177175
intersphinx_mapping = {'https://docs.python.org/': None}

docs/frameworks.rst

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,36 @@ To generate segment based on incoming requests, you need to instantiate the X-Ra
8181
XRayMiddleware(app, xray_recorder)
8282

8383
Flask built-in template rendering will be wrapped into subsegments.
84-
You can configure the recorder, see :ref:`Configure Global Recorder <configurations>` for more details.
84+
You can configure the recorder, see :ref:`Configure Global Recorder <configurations>` for more details.
85+
86+
aiohttp Server
87+
==============
88+
89+
For X-Ray to create a segment based on an incoming request, you need register some middleware with aiohttp. As aiohttp
90+
is an asyncronous framework, X-Ray will also need to be configured with an ``AsyncContext`` compared to the default threadded
91+
version.::
92+
93+
import asyncio
94+
95+
from aiohttp import web
96+
97+
from aws_xray_sdk.ext.aiohttp.middleware import middleware
98+
from aws_xray_sdk.core.async_context import AsyncContext
99+
from aws_xray_sdk.core import xray_recorder
100+
# Configure X-Ray to use AsyncContext
101+
xray_recorder.configure(service='service_name', context=AsyncContext())
102+
103+
104+
async def handler(request):
105+
return web.Response(body='Hello World')
106+
107+
loop = asyncio.get_event_loop()
108+
# Use X-Ray SDK middleware, its crucial the X-Ray middleware comes first
109+
app = web.Application(middlewares=[middleware])
110+
app.router.add_get("/", handler)
111+
112+
web.run_app(app)
113+
114+
There are two things to note from the example above. Firstly a middleware corountine from aws-xray-sdk is provided during the creation
115+
of an aiohttp server app. Lastly the ``xray_recorder`` has also been configured with a name and an ``AsyncContext``. See
116+
:ref:`Configure Global Recorder <configurations>` for more information about configuring the ``xray_recorder``.

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111
name='aws-xray-sdk',
1212
version='0.93',
1313

14-
description='The AWS X-Ray SDK for Python (the SDK) enables Python developers to record and emit information from within their applications to the AWS X-Ray service.',
14+
description='The AWS X-Ray SDK for Python (the SDK) enables Python developers to record'
15+
' and emit information from within their applications to the AWS X-Ray service.',
1516
long_description=long_description,
1617

1718
url='https://github.com/aws/aws-xray-sdk-python',

tests/ext/aiohttp/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)