diff --git a/aws_xray_sdk/core/models/entity.py b/aws_xray_sdk/core/models/entity.py index 648881cc..8b8bdf4f 100644 --- a/aws_xray_sdk/core/models/entity.py +++ b/aws_xray_sdk/core/models/entity.py @@ -18,6 +18,8 @@ _common_invalid_name_characters = '?;*()!$~^<>' _valid_annotation_key_characters = string.ascii_letters + string.digits + '_' +ORIGIN_TRACE_HEADER_ATTR_KEY = '_origin_trace_header' + class Entity(object): """ @@ -228,6 +230,20 @@ def add_exception(self, exception, stack, remote=False): self.cause['exceptions'] = exceptions self.cause['working_directory'] = os.getcwd() + def save_origin_trace_header(self, trace_header): + """ + Temporarily store additional data fields in trace header + to the entity for later propagation. The data will be + cleaned up upon serialization. + """ + setattr(self, ORIGIN_TRACE_HEADER_ATTR_KEY, trace_header) + + def get_origin_trace_header(self): + """ + Retrieve saved trace header data. + """ + return getattr(self, ORIGIN_TRACE_HEADER_ATTR_KEY, None) + def serialize(self): """ Serialize to JSON document that can be accepted by the @@ -258,6 +274,7 @@ def _delete_empty_properties(self, properties): del properties['annotations'] if not self.metadata: del properties['metadata'] + properties.pop(ORIGIN_TRACE_HEADER_ATTR_KEY, None) del properties['sampled'] diff --git a/aws_xray_sdk/core/models/segment.py b/aws_xray_sdk/core/models/segment.py index 3c52fc3c..b07d952c 100644 --- a/aws_xray_sdk/core/models/segment.py +++ b/aws_xray_sdk/core/models/segment.py @@ -155,20 +155,6 @@ def set_rule_name(self, rule_name): self.aws['xray'] = {} self.aws['xray']['sampling_rule_name'] = rule_name - def save_origin_trace_header(self, trace_header): - """ - Temporarily store additional data fields in trace header - to the segment for later propagation. The data will be - cleaned up upon serilaization. - """ - setattr(self, ORIGIN_TRACE_HEADER_ATTR_KEY, trace_header) - - def get_origin_trace_header(self): - """ - Retrieve saved trace header data. - """ - return getattr(self, ORIGIN_TRACE_HEADER_ATTR_KEY, None) - def __getstate__(self): """ Used by jsonpikle to remove unwanted fields. @@ -179,5 +165,4 @@ def __getstate__(self): del properties['user'] del properties['ref_counter'] del properties['_subsegments_counter'] - properties.pop(ORIGIN_TRACE_HEADER_ATTR_KEY, None) return properties diff --git a/aws_xray_sdk/ext/django/middleware.py b/aws_xray_sdk/ext/django/middleware.py index 9f071de9..c2905d30 100644 --- a/aws_xray_sdk/ext/django/middleware.py +++ b/aws_xray_sdk/ext/django/middleware.py @@ -5,6 +5,7 @@ from aws_xray_sdk.core.utils import stacktrace from aws_xray_sdk.ext.util import calculate_sampling_decision, \ calculate_segment_name, construct_xray_header, prepare_response_header +from aws_xray_sdk.core.lambda_launcher import check_in_lambda log = logging.getLogger(__name__) @@ -24,6 +25,10 @@ class XRayMiddleware(object): def __init__(self, get_response): self.get_response = get_response + self.in_lambda = False + + if check_in_lambda(): + self.in_lambda = True # hooks for django version >= 1.10 def __call__(self, request): @@ -46,12 +51,15 @@ def __call__(self, request): sampling_req=sampling_req, ) - segment = xray_recorder.begin_segment( - name=name, - traceid=xray_header.root, - parent_id=xray_header.parent, - sampling=sampling_decision, - ) + if self.in_lambda: + segment = xray_recorder.begin_subsegment(name) + else: + segment = xray_recorder.begin_segment( + name=name, + traceid=xray_header.root, + parent_id=xray_header.parent, + sampling=sampling_decision, + ) segment.save_origin_trace_header(xray_header) segment.put_http_meta(http.URL, request.build_absolute_uri()) @@ -75,7 +83,10 @@ def __call__(self, request): segment.put_http_meta(http.CONTENT_LENGTH, length) response[http.XRAY_HEADER] = prepare_response_header(xray_header, segment) - xray_recorder.end_segment() + if self.in_lambda: + xray_recorder.end_subsegment() + else: + xray_recorder.end_segment() return response diff --git a/aws_xray_sdk/ext/flask/middleware.py b/aws_xray_sdk/ext/flask/middleware.py index 9e0b4877..1a09913b 100644 --- a/aws_xray_sdk/ext/flask/middleware.py +++ b/aws_xray_sdk/ext/flask/middleware.py @@ -5,6 +5,7 @@ from aws_xray_sdk.core.utils import stacktrace from aws_xray_sdk.ext.util import calculate_sampling_decision, \ calculate_segment_name, construct_xray_header, prepare_response_header +from aws_xray_sdk.core.lambda_launcher import check_in_lambda class XRayMiddleware(object): @@ -17,6 +18,10 @@ def __init__(self, app, recorder): self.app.before_request(self._before_request) self.app.after_request(self._after_request) self.app.teardown_request(self._handle_exception) + self.in_lambda = False + + if check_in_lambda(): + self.in_lambda = True _patch_render(recorder) @@ -39,12 +44,15 @@ def _before_request(self): sampling_req=sampling_req, ) - segment = self._recorder.begin_segment( - name=name, - traceid=xray_header.root, - parent_id=xray_header.parent, - sampling=sampling_decision, - ) + if self.in_lambda: + segment = self._recorder.begin_subsegment(name) + else: + segment = self._recorder.begin_segment( + name=name, + traceid=xray_header.root, + parent_id=xray_header.parent, + sampling=sampling_decision, + ) segment.save_origin_trace_header(xray_header) segment.put_http_meta(http.URL, req.base_url) @@ -59,7 +67,10 @@ def _before_request(self): segment.put_http_meta(http.CLIENT_IP, req.remote_addr) def _after_request(self, response): - segment = self._recorder.current_segment() + if self.in_lambda: + segment = self._recorder.current_subsegment() + else: + segment = self._recorder.current_segment() segment.put_http_meta(http.STATUS, response.status_code) origin_header = segment.get_origin_trace_header() @@ -70,7 +81,10 @@ def _after_request(self, response): if cont_len: segment.put_http_meta(http.CONTENT_LENGTH, int(cont_len)) - self._recorder.end_segment() + if self.in_lambda: + self._recorder.end_subsegment() + else: + self._recorder.end_segment() return response def _handle_exception(self, exception): @@ -78,7 +92,10 @@ def _handle_exception(self, exception): return segment = None try: - segment = self._recorder.current_segment() + if self.in_lambda: + segment = self._recorder.current_subsegment() + else: + segment = self._recorder.current_segment() except Exception: pass if not segment: @@ -87,7 +104,10 @@ def _handle_exception(self, exception): segment.put_http_meta(http.STATUS, 500) stack = stacktrace.get_stacktrace(limit=self._recorder._max_trace_back) segment.add_exception(exception, stack) - self._recorder.end_segment() + if self.in_lambda: + self._recorder.end_subsegment() + else: + self._recorder.end_segment() def _patch_render(recorder): diff --git a/tests/ext/django/test_middleware.py b/tests/ext/django/test_middleware.py index a0128b7c..cb36ddf9 100644 --- a/tests/ext/django/test_middleware.py +++ b/tests/ext/django/test_middleware.py @@ -3,9 +3,11 @@ from django.core.urlresolvers import reverse from django.test import TestCase -from aws_xray_sdk.core import xray_recorder +from aws_xray_sdk.core import xray_recorder, lambda_launcher from aws_xray_sdk.core.context import Context -from aws_xray_sdk.core.models import http +from aws_xray_sdk.core.models import http, facade_segment +from tests.util import get_new_stubbed_recorder +import os class XRayTestCase(TestCase): @@ -111,3 +113,22 @@ def test_disabled_sdk(self): self.client.get(url) segment = xray_recorder.emitter.pop() assert not segment + + def test_lambda_serverless(self): + TRACE_ID = '1-5759e988-bd862e3fe1be46a994272793' + PARENT_ID = '53995c3f42cd8ad8' + HEADER_VAR = "Root=%s;Parent=%s;Sampled=1" % (TRACE_ID, PARENT_ID) + + os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY] = HEADER_VAR + lambda_context = lambda_launcher.LambdaContext() + + new_recorder = get_new_stubbed_recorder() + new_recorder.configure(service='test', sampling=False, context=lambda_context) + subsegment = new_recorder.begin_subsegment("subsegment") + assert type(subsegment.parent_segment) == facade_segment.FacadeSegment + new_recorder.end_subsegment() + + url = reverse('200ok') + self.client.get(url) + segment = new_recorder.emitter.pop() + assert not segment diff --git a/tests/ext/flask/test_flask.py b/tests/ext/flask/test_flask.py index 07c8d42c..b5bbf1f7 100644 --- a/tests/ext/flask/test_flask.py +++ b/tests/ext/flask/test_flask.py @@ -4,8 +4,10 @@ from aws_xray_sdk import global_sdk_config from aws_xray_sdk.ext.flask.middleware import XRayMiddleware from aws_xray_sdk.core.context import Context -from aws_xray_sdk.core.models import http +from aws_xray_sdk.core import lambda_launcher +from aws_xray_sdk.core.models import http, facade_segment from tests.util import get_new_stubbed_recorder +import os # define a flask app for testing purpose @@ -153,3 +155,45 @@ def test_disabled_sdk(): app.get(path) segment = recorder.emitter.pop() assert not segment + + +def test_lambda_serverless(): + TRACE_ID = '1-5759e988-bd862e3fe1be46a994272793' + PARENT_ID = '53995c3f42cd8ad8' + HEADER_VAR = "Root=%s;Parent=%s;Sampled=1" % (TRACE_ID, PARENT_ID) + + os.environ[lambda_launcher.LAMBDA_TRACE_HEADER_KEY] = HEADER_VAR + lambda_context = lambda_launcher.LambdaContext() + + new_recorder = get_new_stubbed_recorder() + new_recorder.configure(service='test', sampling=False, context=lambda_context) + new_app = Flask(__name__) + + @new_app.route('/subsegment') + def subsegment(): + # Test in between request and make sure Serverless creates a subsegment instead of a segment. + # Ensure that the parent segment is a facade segment. + assert new_recorder.current_subsegment() + assert type(new_recorder.current_segment()) == facade_segment.FacadeSegment + return 'ok' + + @new_app.route('/trace_header') + def trace_header(): + # Ensure trace header is preserved. + subsegment = new_recorder.current_subsegment() + header = subsegment.get_origin_trace_header() + assert header.data['k1'] == 'v1' + return 'ok' + + middleware = XRayMiddleware(new_app, new_recorder) + middleware.in_lambda = True + + app_client = new_app.test_client() + + path = '/subsegment' + app_client.get(path) + segment = recorder.emitter.pop() + assert not segment # Segment should be none because it's created and ended by the middleware + + path2 = '/trace_header' + app_client.get(path2, headers={http.XRAY_HEADER: 'k1=v1'})