Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to enable/disable the SDK (#26) #119

Merged
merged 8 commits into from
Feb 6, 2019
3 changes: 3 additions & 0 deletions aws_xray_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .sdk_config import SDKConfig

global_sdk_config = SDKConfig()
13 changes: 10 additions & 3 deletions aws_xray_sdk/core/lambda_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import logging
import threading

from aws_xray_sdk import global_sdk_config
from .models.facade_segment import FacadeSegment
from .models.trace_header import TraceHeader
from .context import Context


log = logging.getLogger(__name__)


Expand Down Expand Up @@ -71,7 +71,8 @@ def put_subsegment(self, subsegment):
current_entity = self.get_trace_entity()

if not self._is_subsegment(current_entity) and current_entity.initializing:
log.warning("Subsegment %s discarded due to Lambda worker still initializing" % subsegment.name)
if sdk_config_module.sdk_enabled():
log.warning("Subsegment %s discarded due to Lambda worker still initializing" % subsegment.name)
return

current_entity.add_subsegment(subsegment)
Expand All @@ -93,6 +94,9 @@ def _refresh_context(self):
"""
header_str = os.getenv(LAMBDA_TRACE_HEADER_KEY)
trace_header = TraceHeader.from_header_str(header_str)
if not global_sdk_config.sdk_enabled():
trace_header._sampled = False

segment = getattr(self._local, 'segment', None)

if segment:
Expand Down Expand Up @@ -124,7 +128,10 @@ def _initialize_context(self, trace_header):
set by AWS Lambda and initialize storage for subsegments.
"""
sampled = None
if trace_header.sampled == 0:
if not global_sdk_config.sdk_enabled():
# Force subsequent subsegments to be disabled and turned into DummySegments.
sampled = False
elif trace_header.sampled == 0:
sampled = False
elif trace_header.sampled == 1:
sampled = True
Expand Down
5 changes: 5 additions & 0 deletions aws_xray_sdk/core/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
import wrapt

from aws_xray_sdk import global_sdk_config
from .utils.compat import PY2, is_classmethod, is_instance_method

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -62,6 +63,10 @@ def _is_valid_import(module):


def patch(modules_to_patch, raise_errors=True, ignore_module_patterns=None):
enabled = global_sdk_config.sdk_enabled()
if not enabled:
chanchiem marked this conversation as resolved.
Show resolved Hide resolved
log.debug("Skipped patching modules %s because the SDK is currently disabled." % ', '.join(modules_to_patch))
return # Disable module patching if the SDK is disabled.
modules = set()
for module_to_patch in modules_to_patch:
# boto3 depends on botocore and patching botocore is sufficient
Expand Down
30 changes: 28 additions & 2 deletions aws_xray_sdk/core/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import platform
import time

from aws_xray_sdk import global_sdk_config
from aws_xray_sdk.version import VERSION
from .models.segment import Segment, SegmentContextManager
from .models.subsegment import Subsegment, SubsegmentContextManager
Expand All @@ -18,7 +19,7 @@
from .daemon_config import DaemonConfig
from .plugins.utils import get_plugin_modules
from .lambda_launcher import check_in_lambda
from .exceptions.exceptions import SegmentNameMissingException
from .exceptions.exceptions import SegmentNameMissingException, SegmentNotFoundException
from .utils.compat import string_types
from .utils import stacktrace

Expand Down Expand Up @@ -88,7 +89,6 @@ def configure(self, sampling=None, plugins=None,

Configure needs to run before patching thrid party libraries
to avoid creating dangling subsegment.

:param bool sampling: If sampling is enabled, every time the recorder
creates a segment it decides whether to send this segment to
the X-Ray daemon. This setting is not used if the recorder
Expand Down Expand Up @@ -138,6 +138,7 @@ class to have your own implementation of the streaming process.
and AWS_XRAY_TRACING_NAME respectively overrides arguments
daemon_address, context_missing and service.
"""

if sampling is not None:
self.sampling = sampling
if sampler:
Expand Down Expand Up @@ -219,6 +220,12 @@ def begin_segment(self, name=None, traceid=None,
# depending on if centralized or local sampling rule takes effect.
decision = True

# To disable the recorder, we set the sampling decision to always be false.
# This way, when segments are generated, they become dummy segments and are ultimately never sent.
# The call to self._sampler.should_trace() is never called either so the poller threads are never started.
if not global_sdk_config.sdk_enabled():
sampling = 0

# we respect the input sampling decision
# regardless of recorder configuration.
if sampling == 0:
Expand Down Expand Up @@ -273,6 +280,7 @@ def begin_subsegment(self, name, namespace='local'):
:param str name: the name of the subsegment.
:param str namespace: currently can only be 'local', 'remote', 'aws'.
"""

segment = self.current_segment()
if not segment:
log.warning("No segment found, cannot begin subsegment %s." % name)
Expand Down Expand Up @@ -396,6 +404,16 @@ def capture(self, name=None):
def record_subsegment(self, wrapped, instance, args, kwargs, name,
namespace, meta_processor):

# In the case when the SDK is disabled, we ensure that a parent segment exists, because this is usually
# handled by the middleware. We generate a dummy segment as the parent segment if one doesn't exist.
# This is to allow potential segment method calls to not throw exceptions in the captured method.
if not global_sdk_config.sdk_enabled():
try:
self.current_segment()
except SegmentNotFoundException:
segment = DummySegment(name)
self.context.put_segment(segment)

subsegment = self.begin_subsegment(name, namespace)

exception = None
Expand Down Expand Up @@ -473,6 +491,14 @@ def _is_subsegment(self, entity):

return (hasattr(entity, 'type') and entity.type == 'subsegment')

@property
def enabled(self):
return self._enabled

@enabled.setter
def enabled(self, value):
self._enabled = value

@property
def sampling(self):
return self._sampling
Expand Down
7 changes: 7 additions & 0 deletions aws_xray_sdk/core/sampling/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .target_poller import TargetPoller
from .connector import ServiceConnector
from .reservoir import ReservoirDecision
from aws_xray_sdk import global_sdk_config

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -37,6 +38,9 @@ def start(self):
Start rule poller and target poller once X-Ray daemon address
and context manager is in place.
"""
if not global_sdk_config.sdk_enabled():
return

with self._lock:
if not self._started:
self._rule_poller.start()
Expand All @@ -51,6 +55,9 @@ def should_trace(self, sampling_req=None):
All optional arguments are extracted from incoming requests by
X-Ray middleware to perform path based sampling.
"""
if not global_sdk_config.sdk_enabled():
return False

if not self._started:
self.start() # only front-end that actually uses the sampler spawns poller threads

Expand Down
58 changes: 58 additions & 0 deletions aws_xray_sdk/sdk_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os
import logging

log = logging.getLogger(__name__)


class SDKConfig(object):
"""
Global Configuration Class that defines SDK-level configuration properties.

Enabling/Disabling the SDK:
By default, the SDK is enabled unless if an environment variable AWS_XRAY_SDK_ENABLED
is set. If it is set, it needs to be a valid string boolean, otherwise, it will default
to true. If the environment variable is set, all calls to set_sdk_enabled() will
prioritize the value of the environment variable.
Disabling the SDK affects the recorder, patcher, and middlewares in the following ways:
For the recorder, disabling automatically generates DummySegments for subsequent segments
and DummySubsegments for subsegments created and thus not send any traces to the daemon.
For the patcher, module patching will automatically be disabled. The SDK must be disabled
before calling patcher.patch() method in order for this to function properly.
For the middleware, no modification is made on them, but since the recorder automatically
generates DummySegments for all subsequent calls, they will not generate segments/subsegments
to be sent.

Environment variables:
"AWS_XRAY_SDK_ENABLED" - If set to 'false' disables the SDK and causes the explained above
to occur.
"""
XRAY_ENABLED_KEY = 'AWS_XRAY_SDK_ENABLED'
__SDK_ENABLED = str(os.getenv(XRAY_ENABLED_KEY, 'true')).lower() != 'false'

@classmethod
def sdk_enabled(cls):
"""
Returns whether the SDK is enabled or not.
"""
return cls.__SDK_ENABLED

@classmethod
def set_sdk_enabled(cls, value):
"""
Modifies the enabled flag if the "AWS_XRAY_SDK_ENABLED" environment variable is not set,
otherwise, set the enabled flag to be equal to the environment variable. If the
env variable is an invalid string boolean, it will default to true.

:param bool value: Flag to set whether the SDK is enabled or disabled.

Environment variables AWS_XRAY_SDK_ENABLED overrides argument value.
"""
# Environment Variables take precedence over hardcoded configurations.
if cls.XRAY_ENABLED_KEY in os.environ:
cls.__SDK_ENABLED = str(os.getenv(cls.XRAY_ENABLED_KEY, 'true')).lower() != 'false'
else:
if type(value) == bool:
cls.__SDK_ENABLED = value
else:
cls.__SDK_ENABLED = True
log.warning("Invalid parameter type passed into set_sdk_enabled(). Defaulting to True...")
20 changes: 20 additions & 0 deletions tests/ext/aiohttp/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Expects pytest-aiohttp
"""
import asyncio
from aws_xray_sdk import global_sdk_config
from unittest.mock import patch

from aiohttp import web
Expand Down Expand Up @@ -109,6 +110,7 @@ def recorder(loop):

xray_recorder.clear_trace_entities()
yield xray_recorder
global_sdk_config.set_sdk_enabled(True)
xray_recorder.clear_trace_entities()
patcher.stop()

Expand Down Expand Up @@ -283,3 +285,21 @@ async def get_delay():
# Ensure all ID's are different
ids = [item.id for item in recorder.emitter.local]
assert len(ids) == len(set(ids))


async def test_disabled_sdk(test_client, loop, recorder):
"""
Test a normal response when the SDK is disabled.

:param test_client: AioHttp test client fixture
:param loop: Eventloop fixture
:param recorder: X-Ray recorder fixture
"""
global_sdk_config.set_sdk_enabled(False)
client = await test_client(ServerTest.app(loop=loop))

resp = await client.get('/')
assert resp.status == 200

segment = recorder.emitter.pop()
assert not segment
9 changes: 9 additions & 0 deletions tests/ext/django/test_middleware.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import django
from aws_xray_sdk import global_sdk_config
from django.core.urlresolvers import reverse
from django.test import TestCase

Expand All @@ -14,6 +15,7 @@ def setUp(self):
xray_recorder.configure(context=Context(),
context_missing='LOG_ERROR')
xray_recorder.clear_trace_entities()
global_sdk_config.set_sdk_enabled(True)

def tearDown(self):
xray_recorder.clear_trace_entities()
Expand Down Expand Up @@ -102,3 +104,10 @@ def test_response_header(self):

assert 'Sampled=1' in trace_header
assert segment.trace_id in trace_header

def test_disabled_sdk(self):
global_sdk_config.set_sdk_enabled(False)
url = reverse('200ok')
self.client.get(url)
segment = xray_recorder.emitter.pop()
assert not segment
10 changes: 10 additions & 0 deletions tests/ext/flask/test_flask.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from flask import Flask, render_template_string

from aws_xray_sdk import global_sdk_config
from aws_xray_sdk.ext.flask.middleware import XRayMiddleware
from aws_xray_sdk.core.context import Context
from aws_xray_sdk.core.models import http
Expand Down Expand Up @@ -51,6 +52,7 @@ def cleanup():
recorder.clear_trace_entities()
yield
recorder.clear_trace_entities()
global_sdk_config.set_sdk_enabled(True)


def test_ok():
Expand Down Expand Up @@ -143,3 +145,11 @@ def test_sampled_response_header():
resp_header = resp.headers[http.XRAY_HEADER]
assert segment.trace_id in resp_header
assert 'Sampled=1' in resp_header


def test_disabled_sdk():
global_sdk_config.set_sdk_enabled(False)
path = '/ok'
app.get(path)
segment = recorder.emitter.pop()
assert not segment
19 changes: 19 additions & 0 deletions tests/test_lambda_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

from aws_xray_sdk import global_sdk_config
import pytest
from aws_xray_sdk.core import lambda_launcher
from aws_xray_sdk.core.models.subsegment import Subsegment

Expand All @@ -12,6 +14,12 @@
context = lambda_launcher.LambdaContext()


@pytest.fixture(autouse=True)
def setup():
yield
global_sdk_config.set_sdk_enabled(True)


def test_facade_segment_generation():

segment = context.get_trace_entity()
Expand Down Expand Up @@ -41,3 +49,14 @@ def test_put_subsegment():

context.end_subsegment()
assert context.get_trace_entity().id == segment.id


def test_disable():
context.clear_trace_entities()
segment = context.get_trace_entity()
assert segment.sampled

context.clear_trace_entities()
global_sdk_config.set_sdk_enabled(False)
segment = context.get_trace_entity()
assert not segment.sampled
10 changes: 10 additions & 0 deletions tests/test_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# Python versions < 3 have reload built-in
pass

from aws_xray_sdk import global_sdk_config
from aws_xray_sdk.core import patcher, xray_recorder
from aws_xray_sdk.core.context import Context

Expand Down Expand Up @@ -40,6 +41,7 @@ def construct_ctx():
yield
xray_recorder.end_segment()
xray_recorder.clear_trace_entities()
global_sdk_config.set_sdk_enabled(True)

# Reload wrapt.importer references to modules to start off clean
reload(wrapt)
Expand Down Expand Up @@ -172,3 +174,11 @@ def test_external_submodules_ignores_module():
assert xray_recorder.current_segment().subsegments[0].name == 'mock_init'
assert xray_recorder.current_segment().subsegments[1].name == 'mock_func'
assert xray_recorder.current_segment().subsegments[2].name == 'mock_no_doublepatch' # It is patched with decorator


def test_disable_sdk_disables_patching():
global_sdk_config.set_sdk_enabled(False)
patcher.patch(['tests.mock_module'])
imported_modules = [module for module in TEST_MODULES if module in sys.modules]
assert not imported_modules
assert len(xray_recorder.current_segment().subsegments) == 0
Loading