Skip to content

Add ability to enable/disable the SDK (#26) #119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 6, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aws_xray_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .sdk_config import SDKConfig

global_sdk_config = SDKConfig()
13 changes: 10 additions & 3 deletions aws_xray_sdk/core/lambda_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import logging
import threading

import aws_xray_sdk
from .models.facade_segment import FacadeSegment
from .models.trace_header import TraceHeader
from .context import Context


log = logging.getLogger(__name__)


Expand Down Expand Up @@ -71,7 +71,8 @@ def put_subsegment(self, subsegment):
current_entity = self.get_trace_entity()

if not self._is_subsegment(current_entity) and current_entity.initializing:
log.warning("Subsegment %s discarded due to Lambda worker still initializing" % subsegment.name)
if sdk_config_module.sdk_enabled():
log.warning("Subsegment %s discarded due to Lambda worker still initializing" % subsegment.name)
return

current_entity.add_subsegment(subsegment)
Expand All @@ -93,6 +94,9 @@ def _refresh_context(self):
"""
header_str = os.getenv(LAMBDA_TRACE_HEADER_KEY)
trace_header = TraceHeader.from_header_str(header_str)
if not aws_xray_sdk.global_sdk_config.sdk_enabled():
trace_header._sampled = False

segment = getattr(self._local, 'segment', None)

if segment:
Expand Down Expand Up @@ -124,7 +128,10 @@ def _initialize_context(self, trace_header):
set by AWS Lambda and initialize storage for subsegments.
"""
sampled = None
if trace_header.sampled == 0:
if not aws_xray_sdk.global_sdk_config.sdk_enabled():
# Force subsequent subsegments to be disabled and turned into DummySegments.
sampled = False
elif trace_header.sampled == 0:
sampled = False
elif trace_header.sampled == 1:
sampled = True
Expand Down
4 changes: 4 additions & 0 deletions aws_xray_sdk/core/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
import wrapt

import aws_xray_sdk
from .utils.compat import PY2, is_classmethod, is_instance_method

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,6 +61,9 @@ def _is_valid_import(module):


def patch(modules_to_patch, raise_errors=True, ignore_module_patterns=None):
enabled = aws_xray_sdk.global_sdk_config.sdk_enabled()
if not enabled:
return # Disable module patching if the SDK is disabled.
modules = set()
for module_to_patch in modules_to_patch:
# boto3 depends on botocore and patching botocore is sufficient
Expand Down
31 changes: 29 additions & 2 deletions aws_xray_sdk/core/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import platform
import time

import aws_xray_sdk
from aws_xray_sdk.version import VERSION
from .models.segment import Segment, SegmentContextManager
from .models.subsegment import Subsegment, SubsegmentContextManager
Expand All @@ -18,12 +19,13 @@
from .daemon_config import DaemonConfig
from .plugins.utils import get_plugin_modules
from .lambda_launcher import check_in_lambda
from .exceptions.exceptions import SegmentNameMissingException
from .exceptions.exceptions import SegmentNameMissingException, SegmentNotFoundException
from .utils.compat import string_types
from .utils import stacktrace

log = logging.getLogger(__name__)

XRAY_ENABLED_KEY = 'AWS_XRAY_ENABLED'
TRACING_NAME_KEY = 'AWS_XRAY_TRACING_NAME'
DAEMON_ADDR_KEY = 'AWS_XRAY_DAEMON_ADDRESS'
CONTEXT_MISSING_KEY = 'AWS_XRAY_CONTEXT_MISSING'
Expand Down Expand Up @@ -88,7 +90,6 @@ def configure(self, sampling=None, plugins=None,

Configure needs to run before patching thrid party libraries
to avoid creating dangling subsegment.

:param bool sampling: If sampling is enabled, every time the recorder
creates a segment it decides whether to send this segment to
the X-Ray daemon. This setting is not used if the recorder
Expand Down Expand Up @@ -138,6 +139,7 @@ class to have your own implementation of the streaming process.
and AWS_XRAY_TRACING_NAME respectively overrides arguments
daemon_address, context_missing and service.
"""

if sampling is not None:
self.sampling = sampling
if sampler:
Expand Down Expand Up @@ -219,6 +221,12 @@ def begin_segment(self, name=None, traceid=None,
# depending on if centralized or local sampling rule takes effect.
decision = True

# To disable the recorder, we set the sampling decision to always be false.
# This way, when segments are generated, they become dummy segments and are ultimately never sent.
# The call to self._sampler.should_trace() is never called either so the poller threads are never started.
if not aws_xray_sdk.global_sdk_config.sdk_enabled():
sampling = 0

# we respect the input sampling decision
# regardless of recorder configuration.
if sampling == 0:
Expand Down Expand Up @@ -273,6 +281,7 @@ def begin_subsegment(self, name, namespace='local'):
:param str name: the name of the subsegment.
:param str namespace: currently can only be 'local', 'remote', 'aws'.
"""

segment = self.current_segment()
if not segment:
log.warning("No segment found, cannot begin subsegment %s." % name)
Expand Down Expand Up @@ -396,6 +405,16 @@ def capture(self, name=None):
def record_subsegment(self, wrapped, instance, args, kwargs, name,
namespace, meta_processor):

# In the case when the SDK is disabled, we ensure that a parent segment exists, because this is usually
# handled by the middleware. We generate a dummy segment as the parent segment if one doesn't exist.
# This is to allow potential segment method calls to not throw exceptions in the captured method.
if not aws_xray_sdk.global_sdk_config.sdk_enabled():
try:
self.current_segment()
except SegmentNotFoundException:
segment = DummySegment(name)
self.context.put_segment(segment)

subsegment = self.begin_subsegment(name, namespace)

exception = None
Expand Down Expand Up @@ -473,6 +492,14 @@ def _is_subsegment(self, entity):

return (hasattr(entity, 'type') and entity.type == 'subsegment')

@property
def enabled(self):
return self._enabled

@enabled.setter
def enabled(self, value):
self._enabled = value

@property
def sampling(self):
return self._sampling
Expand Down
7 changes: 7 additions & 0 deletions aws_xray_sdk/core/sampling/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .target_poller import TargetPoller
from .connector import ServiceConnector
from .reservoir import ReservoirDecision
import aws_xray_sdk

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -37,6 +38,9 @@ def start(self):
Start rule poller and target poller once X-Ray daemon address
and context manager is in place.
"""
if not aws_xray_sdk.global_sdk_config.sdk_enabled():
return

with self._lock:
if not self._started:
self._rule_poller.start()
Expand All @@ -51,6 +55,9 @@ def should_trace(self, sampling_req=None):
All optional arguments are extracted from incoming requests by
X-Ray middleware to perform path based sampling.
"""
if not aws_xray_sdk.global_sdk_config.sdk_enabled():
return False

if not self._started:
self.start() # only front-end that actually uses the sampler spawns poller threads

Expand Down
64 changes: 64 additions & 0 deletions aws_xray_sdk/sdk_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import os


class InvalidParameterTypeException(Exception):
"""
Exception thrown when an invalid parameter is passed into SDKConfig.set_sdk_enabled.
"""
pass


class SDKConfig(object):
"""
Global Configuration Class that defines SDK-level configuration properties.

Enabling/Disabling the SDK:
By default, the SDK is enabled unless if an environment variable AWS_XRAY_SDK_ENABLED
is set. If it is set, it needs to be a valid string boolean, otherwise, it will default
to true. If the environment variable is set, all calls to set_sdk_enabled() will
prioritize the value of the environment variable.
Disabling the SDK affects the recorder, patcher, and middlewares in the following ways:
For the recorder, disabling automatically generates DummySegments for subsequent segments
and DummySubsegments for subsegments created and thus not send any traces to the daemon.
For the patcher, module patching will automatically be disabled. The SDK must be disabled
before calling patcher.patch() method in order for this to function properly.
For the middleware, no modification is made on them, but since the recorder automatically
generates DummySegments for all subsequent calls, they will not generate segments/subsegments
to be sent.

Environment variables:
"AWS_XRAY_SDK_ENABLED" - If set to 'false' disables the SDK and causes the explained above
to occur.
"""
XRAY_ENABLED_KEY = 'AWS_XRAY_SDK_ENABLED'
__SDK_ENABLED = str(os.getenv(XRAY_ENABLED_KEY, 'true')).lower() != 'false'

@classmethod
def sdk_enabled(cls):
"""
Returns whether the SDK is enabled or not.
"""
return cls.__SDK_ENABLED

@classmethod
def set_sdk_enabled(cls, value):
"""
Modifies the enabled flag if the "AWS_XRAY_SDK_ENABLED" environment variable is not set,
otherwise, set the enabled flag to be equal to the environment variable. If the
env variable is an invalid string boolean, it will default to true.

:param bool value: Flag to set whether the SDK is enabled or disabled.

Environment variables AWS_XRAY_SDK_ENABLED overrides argument value.
"""
# Environment Variables take precedence over hardcoded configurations.
if cls.XRAY_ENABLED_KEY in os.environ:
cls.__SDK_ENABLED = str(os.getenv(cls.XRAY_ENABLED_KEY, 'true')).lower() != 'false'
else:
if type(value) == bool:
cls.__SDK_ENABLED = value
else:
cls.__SDK_ENABLED = True
raise InvalidParameterTypeException(
"Invalid parameter type passed into set_sdk_enabled(). Defaulting to True..."
)
20 changes: 20 additions & 0 deletions tests/ext/aiohttp/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Expects pytest-aiohttp
"""
import asyncio
import aws_xray_sdk
from unittest.mock import patch

from aiohttp import web
Expand Down Expand Up @@ -109,6 +110,7 @@ def recorder(loop):

xray_recorder.clear_trace_entities()
yield xray_recorder
aws_xray_sdk.global_sdk_config.set_sdk_enabled(True)
xray_recorder.clear_trace_entities()
patcher.stop()

Expand Down Expand Up @@ -283,3 +285,21 @@ async def get_delay():
# Ensure all ID's are different
ids = [item.id for item in recorder.emitter.local]
assert len(ids) == len(set(ids))


async def test_disabled_sdk(test_client, loop, recorder):
"""
Test a normal response when the SDK is disabled.

:param test_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
aws_xray_sdk.global_sdk_config.set_sdk_enabled(False)
client = await test_client(ServerTest.app(loop=loop))

resp = await client.get('/')
assert resp.status == 200

segment = recorder.emitter.pop()
assert not segment
9 changes: 9 additions & 0 deletions tests/ext/django/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import django
import aws_xray_sdk
from django.core.urlresolvers import reverse
from django.test import TestCase

Expand All @@ -14,6 +15,7 @@ def setUp(self):
xray_recorder.configure(context=Context(),
context_missing='LOG_ERROR')
xray_recorder.clear_trace_entities()
aws_xray_sdk.global_sdk_config.set_sdk_enabled(True)

def tearDown(self):
xray_recorder.clear_trace_entities()
Expand Down Expand Up @@ -102,3 +104,10 @@ def test_response_header(self):

assert 'Sampled=1' in trace_header
assert segment.trace_id in trace_header

def test_disabled_sdk(self):
aws_xray_sdk.global_sdk_config.set_sdk_enabled(False)
url = reverse('200ok')
self.client.get(url)
segment = xray_recorder.emitter.pop()
assert not segment
10 changes: 10 additions & 0 deletions tests/ext/flask/test_flask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from flask import Flask, render_template_string

import aws_xray_sdk
from aws_xray_sdk.ext.flask.middleware import XRayMiddleware
from aws_xray_sdk.core.context import Context
from aws_xray_sdk.core.models import http
Expand Down Expand Up @@ -51,6 +52,7 @@ def cleanup():
recorder.clear_trace_entities()
yield
recorder.clear_trace_entities()
aws_xray_sdk.global_sdk_config.set_sdk_enabled(True)


def test_ok():
Expand Down Expand Up @@ -143,3 +145,11 @@ def test_sampled_response_header():
resp_header = resp.headers[http.XRAY_HEADER]
assert segment.trace_id in resp_header
assert 'Sampled=1' in resp_header


def test_disabled_sdk():
aws_xray_sdk.global_sdk_config.set_sdk_enabled(False)
path = '/ok'
app.get(path)
segment = recorder.emitter.pop()
assert not segment
19 changes: 19 additions & 0 deletions tests/test_lambda_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

import aws_xray_sdk
import pytest
from aws_xray_sdk.core import lambda_launcher
from aws_xray_sdk.core.models.subsegment import Subsegment

Expand All @@ -12,6 +14,12 @@
context = lambda_launcher.LambdaContext()


@pytest.fixture(autouse=True)
def setup():
yield
aws_xray_sdk.global_sdk_config.set_sdk_enabled(True)


def test_facade_segment_generation():

segment = context.get_trace_entity()
Expand Down Expand Up @@ -41,3 +49,14 @@ def test_put_subsegment():

context.end_subsegment()
assert context.get_trace_entity().id == segment.id


def test_disable():
context.clear_trace_entities()
segment = context.get_trace_entity()
assert segment.sampled

context.clear_trace_entities()
aws_xray_sdk.global_sdk_config.set_sdk_enabled(False)
segment = context.get_trace_entity()
assert not segment.sampled
Loading