diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 441274f3..ebe3496c 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,10 @@ CHANGELOG ========= +unreleased +========== +* feature: Use the official middleware pattern for Aiohttp ext. `PR29 `_. + 0.96 ==== * feature: Add support for SQLAlchemy and Flask-SQLAlcemy. `PR14 `_. diff --git a/aws_xray_sdk/ext/aiohttp/middleware.py b/aws_xray_sdk/ext/aiohttp/middleware.py index 47c7248f..74b04bdd 100644 --- a/aws_xray_sdk/ext/aiohttp/middleware.py +++ b/aws_xray_sdk/ext/aiohttp/middleware.py @@ -1,6 +1,7 @@ """ AioHttp Middleware """ +import aiohttp import traceback from aws_xray_sdk.core import xray_recorder @@ -8,71 +9,66 @@ from aws_xray_sdk.ext.util import calculate_sampling_decision, calculate_segment_name, construct_xray_header -async def middleware(app, handler): +@aiohttp.web.middleware +async def middleware(request, handler): """ - AioHttp Middleware Factory + Main middleware function, deals with all the X-Ray segment logic """ - async def _middleware(request): - """ - Main middleware function, deals with all the X-Ray segment logic - """ - # Create X-Ray headers - xray_header = construct_xray_header(request.headers) - # Get name of service or generate a dynamic one from host - name = calculate_segment_name(request.headers['host'].split(':', 1)[0], xray_recorder) + # Create X-Ray headers + xray_header = construct_xray_header(request.headers) + # Get name of service or generate a dynamic one from host + name = calculate_segment_name(request.headers['host'].split(':', 1)[0], xray_recorder) - sampling_decision = calculate_sampling_decision( - trace_header=xray_header, - recorder=xray_recorder, - service_name=request.headers['host'], - method=request.method, - path=request.path, - ) + sampling_decision = calculate_sampling_decision( + trace_header=xray_header, + recorder=xray_recorder, + service_name=request.headers['host'], + method=request.method, + path=request.path, + ) - # Start a segment - segment = xray_recorder.begin_segment( - name=name, - traceid=xray_header.root, - parent_id=xray_header.parent, - sampling=sampling_decision, - ) + # Start a segment + segment = xray_recorder.begin_segment( + name=name, + traceid=xray_header.root, + parent_id=xray_header.parent, + sampling=sampling_decision, + ) - # Store request metadata in the current segment - segment.put_http_meta(http.URL, request.url) - segment.put_http_meta(http.METHOD, request.method) + # Store request metadata in the current segment + segment.put_http_meta(http.URL, request.url) + segment.put_http_meta(http.METHOD, request.method) - if 'User-Agent' in request.headers: - segment.put_http_meta(http.USER_AGENT, request.headers['User-Agent']) + if 'User-Agent' in request.headers: + segment.put_http_meta(http.USER_AGENT, request.headers['User-Agent']) - if 'X-Forwarded-For' in request.headers: - segment.put_http_meta(http.CLIENT_IP, request.headers['X-Forwarded-For']) - segment.put_http_meta(http.X_FORWARDED_FOR, True) - elif 'remote_addr' in request.headers: - segment.put_http_meta(http.CLIENT_IP, request.headers['remote_addr']) - else: - segment.put_http_meta(http.CLIENT_IP, request.remote) + if 'X-Forwarded-For' in request.headers: + segment.put_http_meta(http.CLIENT_IP, request.headers['X-Forwarded-For']) + segment.put_http_meta(http.X_FORWARDED_FOR, True) + elif 'remote_addr' in request.headers: + segment.put_http_meta(http.CLIENT_IP, request.headers['remote_addr']) + else: + segment.put_http_meta(http.CLIENT_IP, request.remote) - try: - # Call next middleware or request handler - response = await handler(request) - except Exception as err: - # Store exception information including the stacktrace to the segment - segment = xray_recorder.current_segment() - segment.put_http_meta(http.STATUS, 500) - stack = traceback.extract_stack(limit=xray_recorder._max_trace_back) - segment.add_exception(err, stack) - xray_recorder.end_segment() - raise - - # Store response metadata into the current segment - segment.put_http_meta(http.STATUS, response.status) + try: + # Call next middleware or request handler + response = await handler(request) + except Exception as err: + # Store exception information including the stacktrace to the segment + segment = xray_recorder.current_segment() + segment.put_http_meta(http.STATUS, 500) + stack = traceback.extract_stack(limit=xray_recorder._max_trace_back) + segment.add_exception(err, stack) + xray_recorder.end_segment() + raise - if 'Content-Length' in response.headers: - length = int(response.headers['Content-Length']) - segment.put_http_meta(http.CONTENT_LENGTH, length) + # Store response metadata into the current segment + segment.put_http_meta(http.STATUS, response.status) - # Close segment so it can be dispatched off to the daemon - xray_recorder.end_segment() + if 'Content-Length' in response.headers: + length = int(response.headers['Content-Length']) + segment.put_http_meta(http.CONTENT_LENGTH, length) - return response - return _middleware + # Close segment so it can be dispatched off to the daemon + xray_recorder.end_segment() + return response diff --git a/docs/index.rst b/docs/index.rst index 7ac78bc6..ebd8e9e3 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -20,7 +20,7 @@ You can get started in minutes using ``pip`` or by downloading a zip file. Currently supported web frameworks and libraries: * aioboto3/aiobotocore -* aiohttp +* aiohttp >=2.3 * boto3/botocore * Django >=1.10 * Flask