diff --git a/examples/opentelemetry-example-app/src/opentelemetry_example_app/context_propagation_example.py b/examples/opentelemetry-example-app/src/opentelemetry_example_app/context_propagation_example.py new file mode 100644 index 0000000000..90217b2207 --- /dev/null +++ b/examples/opentelemetry-example-app/src/opentelemetry_example_app/context_propagation_example.py @@ -0,0 +1,98 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +This module serves as an example for baggage, which exists +to pass application-defined key-value pairs from service to service. +""" +# import opentelemetry.ext.http_requests +# from opentelemetry.ext.wsgi import OpenTelemetryMiddleware + +import flask +import requests +from flask import request + +import opentelemetry.ext.http_requests +from opentelemetry import propagation, trace +from opentelemetry.correlationcontext import CorrelationContextManager +from opentelemetry.ext.wsgi import OpenTelemetryMiddleware +from opentelemetry.sdk.context.propagation import b3_format +from opentelemetry.sdk.trace import TracerSource +from opentelemetry.sdk.trace.export import ( + BatchExportSpanProcessor, + ConsoleSpanExporter, +) + + +def configure_opentelemetry(flask_app: flask.Flask): + trace.set_preferred_tracer_source_implementation(lambda T: TracerSource()) + trace.tracer_source().add_span_processor( + BatchExportSpanProcessor(ConsoleSpanExporter()) + ) + + # Global initialization + (b3_extractor, b3_injector) = b3_format.http_propagator() + # propagation.set_http_extractors([b3_extractor, baggage_extractor]) + # propagation.set_http_injectors([b3_injector, baggage_injector]) + propagation.set_http_extractors([b3_extractor]) + propagation.set_http_injectors([b3_injector]) + + opentelemetry.ext.http_requests.enable(trace.tracer_source()) + flask_app.wsgi_app = OpenTelemetryMiddleware(flask_app.wsgi_app) + + +def fetch_from_service_b() -> str: + with trace.tracer_source().get_tracer(__name__).start_as_current_span( + "fetch_from_service_b" + ): + # Inject the contexts to be propagated. Note that there is no direct + # reference to tracing or baggage. + headers = {"Accept": "text/html"} + propagation.inject(headers) + resp = requests.get("https://opentelemetry.io", headers=headers) + return resp.text + + +def fetch_from_service_c() -> str: + with trace.tracer_source().get_tracer(__name__).start_as_current_span( + "fetch_from_service_c" + ): + # Inject the contexts to be propagated. Note that there is no direct + # reference to tracing or baggage. + headers = {"Accept": "application/json"} + propagation.inject(headers) + resp = requests.get("https://opentelemetry.io", headers=headers) + return resp.text + + +app = flask.Flask(__name__) + + +@app.route("/") +def hello(): + tracer = trace.tracer_source().get_tracer(__name__) + # extract a baggage header + propagation.extract(request.headers) + + with tracer.start_as_current_span("service-span"): + with tracer.start_as_current_span("external-req-span"): + version = CorrelationContextManager.correlation("version") + if version == "2.0": + return fetch_from_service_c() + return fetch_from_service_b() + + +if __name__ == "__main__": + configure_opentelemetry(app) + app.run(debug=True) diff --git a/examples/opentelemetry-example-app/src/opentelemetry_example_app/flask_example.py b/examples/opentelemetry-example-app/src/opentelemetry_example_app/flask_example.py index ae484dd30e..46ed601fc9 100644 --- a/examples/opentelemetry-example-app/src/opentelemetry_example_app/flask_example.py +++ b/examples/opentelemetry-example-app/src/opentelemetry_example_app/flask_example.py @@ -21,9 +21,12 @@ import requests import opentelemetry.ext.http_requests -from opentelemetry import propagators, trace +from opentelemetry import propagation, trace from opentelemetry.ext.flask import instrument_app -from opentelemetry.sdk.context.propagation.b3_format import B3Format +from opentelemetry.sdk.context.propagation.b3_format import ( + B3Extractor, + B3Injector, +) from opentelemetry.sdk.trace import TracerSource @@ -52,7 +55,8 @@ def configure_opentelemetry(flask_app: flask.Flask): # carry this value). # TBD: can remove once default TraceContext propagators are installed. - propagators.set_global_httptextformat(B3Format()) + propagation.set_http_extractors([B3Extractor]) + propagation.set_http_injectors([B3Injector]) # Integrations are the glue that binds the OpenTelemetry API # and the frameworks and libraries that are used together, automatically diff --git a/ext/opentelemetry-ext-flask/src/opentelemetry/ext/flask/__init__.py b/ext/opentelemetry-ext-flask/src/opentelemetry/ext/flask/__init__.py index ce11b18d63..25f9865b5e 100644 --- a/ext/opentelemetry-ext-flask/src/opentelemetry/ext/flask/__init__.py +++ b/ext/opentelemetry-ext-flask/src/opentelemetry/ext/flask/__init__.py @@ -6,8 +6,9 @@ from flask import request as flask_request import opentelemetry.ext.wsgi as otel_wsgi -from opentelemetry import propagators, trace +from opentelemetry import propagation, trace from opentelemetry.ext.flask.version import __version__ +from opentelemetry.trace.propagation.context import span_context_from_context from opentelemetry.util import time_ns logger = logging.getLogger(__name__) @@ -57,10 +58,12 @@ def _before_flask_request(): span_name = flask_request.endpoint or otel_wsgi.get_default_span_name( environ ) - parent_span = propagators.extract( - otel_wsgi.get_header_from_environ, environ + + propagation.extract( + environ, get_from_carrier=otel_wsgi.get_header_from_environ ) + parent_span = span_context_from_context() tracer = trace.tracer_source().get_tracer(__name__, __version__) attributes = otel_wsgi.collect_request_attributes(environ) diff --git a/ext/opentelemetry-ext-http-requests/src/opentelemetry/ext/http_requests/__init__.py b/ext/opentelemetry-ext-http-requests/src/opentelemetry/ext/http_requests/__init__.py index 4f5a18cf9e..5595b9fe8f 100644 --- a/ext/opentelemetry-ext-http-requests/src/opentelemetry/ext/http_requests/__init__.py +++ b/ext/opentelemetry-ext-http-requests/src/opentelemetry/ext/http_requests/__init__.py @@ -22,8 +22,8 @@ from requests.sessions import Session -from opentelemetry import propagators -from opentelemetry.context import Context +from opentelemetry import propagation +from opentelemetry.context import current from opentelemetry.ext.http_requests.version import __version__ from opentelemetry.trace import SpanKind @@ -54,7 +54,7 @@ def enable(tracer_source): @functools.wraps(wrapped) def instrumented_request(self, method, url, *args, **kwargs): - if Context.suppress_instrumentation: + if current().value("suppress_instrumentation"): return wrapped(self, method, url, *args, **kwargs) # See @@ -77,7 +77,7 @@ def instrumented_request(self, method, url, *args, **kwargs): # to access propagators. headers = kwargs.setdefault("headers", {}) - propagators.inject(tracer, type(headers).__setitem__, headers) + propagation.inject(headers) result = wrapped(self, method, url, *args, **kwargs) # *** PROCEED span.set_attribute("http.status_code", result.status_code) diff --git a/ext/opentelemetry-ext-opentracing-shim/src/opentelemetry/ext/opentracing_shim/__init__.py b/ext/opentelemetry-ext-opentracing-shim/src/opentelemetry/ext/opentracing_shim/__init__.py index 7c7640017b..5b65ed2d34 100644 --- a/ext/opentelemetry-ext-opentracing-shim/src/opentelemetry/ext/opentracing_shim/__init__.py +++ b/ext/opentelemetry-ext-opentracing-shim/src/opentelemetry/ext/opentracing_shim/__init__.py @@ -86,9 +86,14 @@ from deprecated import deprecated import opentelemetry.trace as trace_api -from opentelemetry import propagators +from opentelemetry import propagation +from opentelemetry.context import Context from opentelemetry.ext.opentracing_shim import util from opentelemetry.ext.opentracing_shim.version import __version__ +from opentelemetry.trace.propagation.context import ( + span_context_from_context, + with_span_context, +) logger = logging.getLogger(__name__) @@ -664,16 +669,14 @@ def inject(self, span_context, format, carrier): # TODO: Finish documentation. # pylint: disable=redefined-builtin # This implementation does not perform the injecting by itself but - # uses the configured propagators in opentelemetry.propagators. + # uses the configured propagation in opentelemetry.propagation. # TODO: Support Format.BINARY once it is supported in # opentelemetry-python. if format not in self._supported_formats: raise opentracing.UnsupportedFormatException - propagator = propagators.get_global_httptextformat() - propagator.inject( - span_context.unwrap(), type(carrier).__setitem__, carrier - ) + ctx = with_span_context(span_context.unwrap()) + propagation.inject(carrier, context=ctx) def extract(self, format, carrier): """Implements the ``extract`` method from the base class.""" @@ -681,17 +684,13 @@ def extract(self, format, carrier): # TODO: Finish documentation. # pylint: disable=redefined-builtin # This implementation does not perform the extracing by itself but - # uses the configured propagators in opentelemetry.propagators. + # uses the configured propagation in opentelemetry.propagation. # TODO: Support Format.BINARY once it is supported in # opentelemetry-python. if format not in self._supported_formats: raise opentracing.UnsupportedFormatException - def get_as_list(dict_object, key): - value = dict_object.get(key) - return [value] if value is not None else [] - - propagator = propagators.get_global_httptextformat() - otel_context = propagator.extract(get_as_list, carrier) + propagation.extract(carrier) + otel_context = span_context_from_context() return SpanContextShim(otel_context) diff --git a/ext/opentelemetry-ext-opentracing-shim/tests/test_shim.py b/ext/opentelemetry-ext-opentracing-shim/tests/test_shim.py index d42098dce7..2a90221119 100644 --- a/ext/opentelemetry-ext-opentracing-shim/tests/test_shim.py +++ b/ext/opentelemetry-ext-opentracing-shim/tests/test_shim.py @@ -18,10 +18,19 @@ import opentracing import opentelemetry.ext.opentracing_shim as opentracingshim -from opentelemetry import propagators, trace -from opentelemetry.context.propagation.httptextformat import HTTPTextFormat +from opentelemetry import propagation, trace from opentelemetry.ext.opentracing_shim import util +from opentelemetry.propagation import ( + Extractor, + Injector, + get_as_list, + set_in_dict, +) from opentelemetry.sdk.trace import TracerSource +from opentelemetry.trace.propagation.context import ( + span_context_from_context, + with_span_context, +) class TestShim(unittest.TestCase): @@ -43,15 +52,18 @@ def setUpClass(cls): ) # Save current propagator to be restored on teardown. - cls._previous_propagator = propagators.get_global_httptextformat() + cls._previous_injectors = propagation.get_http_injectors() + cls._previous_extractors = propagation.get_http_extractors() # Set mock propagator for testing. - propagators.set_global_httptextformat(MockHTTPTextFormat) + propagation.set_http_extractors([MockHTTPExtractor]) + propagation.set_http_injectors([MockHTTPInjector]) @classmethod def tearDownClass(cls): # Restore previous propagator. - propagators.set_global_httptextformat(cls._previous_propagator) + propagation.set_http_extractors(cls._previous_extractors) + propagation.set_http_injectors(cls._previous_injectors) def test_shim_type(self): # Verify shim is an OpenTracing tracer. @@ -475,8 +487,8 @@ def test_inject_http_headers(self): headers = {} self.shim.inject(context, opentracing.Format.HTTP_HEADERS, headers) - self.assertEqual(headers[MockHTTPTextFormat.TRACE_ID_KEY], str(1220)) - self.assertEqual(headers[MockHTTPTextFormat.SPAN_ID_KEY], str(7478)) + self.assertEqual(headers[_TRACE_ID_KEY], str(1220)) + self.assertEqual(headers[_SPAN_ID_KEY], str(7478)) def test_inject_text_map(self): """Test `inject()` method for Format.TEXT_MAP.""" @@ -487,8 +499,8 @@ def test_inject_text_map(self): # Verify Format.TEXT_MAP text_map = {} self.shim.inject(context, opentracing.Format.TEXT_MAP, text_map) - self.assertEqual(text_map[MockHTTPTextFormat.TRACE_ID_KEY], str(1220)) - self.assertEqual(text_map[MockHTTPTextFormat.SPAN_ID_KEY], str(7478)) + self.assertEqual(text_map[_TRACE_ID_KEY], str(1220)) + self.assertEqual(text_map[_SPAN_ID_KEY], str(7478)) def test_inject_binary(self): """Test `inject()` method for Format.BINARY.""" @@ -504,8 +516,8 @@ def test_extract_http_headers(self): """Test `extract()` method for Format.HTTP_HEADERS.""" carrier = { - MockHTTPTextFormat.TRACE_ID_KEY: 1220, - MockHTTPTextFormat.SPAN_ID_KEY: 7478, + _TRACE_ID_KEY: 1220, + _SPAN_ID_KEY: 7478, } ctx = self.shim.extract(opentracing.Format.HTTP_HEADERS, carrier) @@ -516,8 +528,8 @@ def test_extract_text_map(self): """Test `extract()` method for Format.TEXT_MAP.""" carrier = { - MockHTTPTextFormat.TRACE_ID_KEY: 1220, - MockHTTPTextFormat.SPAN_ID_KEY: 7478, + _TRACE_ID_KEY: 1220, + _SPAN_ID_KEY: 7478, } ctx = self.shim.extract(opentracing.Format.TEXT_MAP, carrier) @@ -532,25 +544,33 @@ def test_extract_binary(self): self.shim.extract(opentracing.Format.BINARY, bytearray()) -class MockHTTPTextFormat(HTTPTextFormat): - """Mock propagator for testing purposes.""" +_TRACE_ID_KEY = "mock-traceid" +_SPAN_ID_KEY = "mock-spanid" - TRACE_ID_KEY = "mock-traceid" - SPAN_ID_KEY = "mock-spanid" + +class MockHTTPExtractor(Extractor): + """Mock extractor for testing purposes.""" @classmethod - def extract(cls, get_from_carrier, carrier): - trace_id_list = get_from_carrier(carrier, cls.TRACE_ID_KEY) - span_id_list = get_from_carrier(carrier, cls.SPAN_ID_KEY) + def extract(cls, carrier, context=None, get_from_carrier=get_as_list): + trace_id_list = get_from_carrier(carrier, _TRACE_ID_KEY) + span_id_list = get_from_carrier(carrier, _SPAN_ID_KEY) if not trace_id_list or not span_id_list: - return trace.INVALID_SPAN_CONTEXT + return with_span_context(trace.INVALID_SPAN_CONTEXT) - return trace.SpanContext( - trace_id=int(trace_id_list[0]), span_id=int(span_id_list[0]) + return with_span_context( + trace.SpanContext( + trace_id=int(trace_id_list[0]), span_id=int(span_id_list[0]) + ) ) + +class MockHTTPInjector(Injector): + """Mock injector for testing purposes.""" + @classmethod - def inject(cls, context, set_in_carrier, carrier): - set_in_carrier(carrier, cls.TRACE_ID_KEY, str(context.trace_id)) - set_in_carrier(carrier, cls.SPAN_ID_KEY, str(context.span_id)) + def inject(cls, carrier, context=None, set_in_carrier=set_in_dict): + sc = span_context_from_context(context) + set_in_carrier(carrier, _TRACE_ID_KEY, str(sc.trace_id)) + set_in_carrier(carrier, _SPAN_ID_KEY, str(sc.span_id)) diff --git a/ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py b/ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py index 6581662d59..53cc337798 100644 --- a/ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py +++ b/ext/opentelemetry-ext-wsgi/src/opentelemetry/ext/wsgi/__init__.py @@ -22,8 +22,9 @@ import typing import wsgiref.util as wsgiref_util -from opentelemetry import propagators, trace +from opentelemetry import propagation, trace from opentelemetry.ext.wsgi.version import __version__ +from opentelemetry.trace.propagation.context import span_context_from_context from opentelemetry.trace.status import Status, StatusCanonicalCode _HTTP_VERSION_PREFIX = "HTTP/" @@ -183,7 +184,9 @@ def __call__(self, environ, start_response): start_response: The WSGI start_response callable. """ - parent_span = propagators.extract(get_header_from_environ, environ) + propagation.extract(environ, get_from_carrier=get_header_from_environ) + + parent_span = span_context_from_context() span_name = get_default_span_name(environ) span = self.tracer.start_span( diff --git a/opentelemetry-api/src/opentelemetry/context/__init__.py b/opentelemetry-api/src/opentelemetry/context/__init__.py index 43a7722f88..f3a843f47b 100644 --- a/opentelemetry-api/src/opentelemetry/context/__init__.py +++ b/opentelemetry-api/src/opentelemetry/context/__init__.py @@ -138,15 +138,141 @@ async def main(): asyncio.run(main()) """ -from .base_context import BaseRuntimeContext +import threading +import typing +from contextlib import contextmanager -__all__ = ["Context"] +from .base_context import Context, Slot try: - from .async_context import AsyncRuntimeContext + from .async_context import ( + AsyncRuntimeContext, + ContextVarSlot, + ) - Context = AsyncRuntimeContext() # type: BaseRuntimeContext + _context_class = AsyncRuntimeContext # pylint: disable=invalid-name + _slot_class = ContextVarSlot # pylint: disable=invalid-name except ImportError: - from .thread_local_context import ThreadLocalRuntimeContext - - Context = ThreadLocalRuntimeContext() + from .thread_local_context import ( + ThreadLocalRuntimeContext, + ThreadLocalSlot, + ) + + _context_class = ThreadLocalRuntimeContext # pylint: disable=invalid-name + _slot_class = ThreadLocalSlot # pylint: disable=invalid-name + +_slots = {} # type: typing.Dict[str, 'Slot'] +_lock = threading.Lock() + + +def _register_slot(name: str, default: "object" = None) -> Slot: + """Register a context slot with an optional default value. + + :type name: str + :param name: The name of the context slot. + + :type default: object + :param name: The default value of the slot, can be a value or lambda. + + :returns: The registered slot. + """ + with _lock: + if name not in _slots: + _slots[name] = _slot_class(name, default) # type: Slot + return _slots[name] + + +def set_value( + name: str, val: "object", context: typing.Optional[Context] = None, +) -> Context: + """ + To record the local state of a cross-cutting concern, the + Context API provides a function which takes a context, a + key, and a value as input, and returns an updated context + which contains the new value. + + Args: + name: name of the entry to set + value: value of the entry to set + context: a context to copy, if None, the current context is used + """ + # Function inside the module that performs the action on the current context + # or in the passsed one based on the context object + if context: + ret = Context() + ret.snapshot = dict((n, v) for n, v in context.snapshot.items()) + ret.snapshot[name] = val + return ret + + # update value on current context: + slot = _register_slot(name) + slot.set(val) + return current() + + +def value(name: str, context: Context = None) -> typing.Optional["object"]: + """ + To access the local state of an concern, the Context API + provides a function which takes a context and a key as input, + and returns a value. + + Args: + name: name of the entry to retrieve + context: a context from which to retrieve the value, if None, the current context is used + """ + if context: + return context.value(name) + + # get context from current context + if name in _slots: + return _slots[name].get() + return None + + +def current() -> Context: + """ + To access the context associated with program execution, + the Context API provides a function which takes no arguments + and returns a Context. + """ + ret = Context() + for key, slot in _slots.items(): + ret.snapshot[key] = slot.get() + + return ret + + +def set_current(context: Context) -> None: + """ + To associate a context with program execution, the Context + API provides a function which takes a Context. + """ + _slots.clear() # remove current data + + for key, val in context.snapshot.items(): + slot = _register_slot(key) + slot.set(val) + + +@contextmanager +def use(**kwargs: typing.Dict[str, object]) -> typing.Iterator[None]: + snapshot = current() + for key in kwargs: + set_value(key, kwargs[key]) + yield + set_current(snapshot) + + +def new_context() -> Context: + return _context_class() + + +def merge_context_correlation(source: Context, dest: Context) -> Context: + ret = Context() + + for key in dest.snapshot: + ret.snapshot[key] = dest.snapshot[key] + + for key in source.snapshot: + ret.snapshot[key] = source.snapshot[key] + return ret diff --git a/opentelemetry-api/src/opentelemetry/context/async_context.py b/opentelemetry-api/src/opentelemetry/context/async_context.py index 267059fb31..fec66dce62 100644 --- a/opentelemetry-api/src/opentelemetry/context/async_context.py +++ b/opentelemetry-api/src/opentelemetry/context/async_context.py @@ -17,29 +17,54 @@ except ImportError: pass else: - import typing # pylint: disable=unused-import + # import contextvars + import typing from . import base_context - class AsyncRuntimeContext(base_context.BaseRuntimeContext): - class Slot(base_context.BaseRuntimeContext.Slot): - def __init__(self, name: str, default: object): - # pylint: disable=super-init-not-called - self.name = name - self.contextvar = ContextVar(name) # type: ContextVar[object] - self.default = base_context.wrap_callable( - default - ) # type: typing.Callable[..., object] - - def clear(self) -> None: - self.contextvar.set(self.default()) - - def get(self) -> object: - try: - return self.contextvar.get() - except LookupError: - value = self.default() - self.set(value) - return value - - def set(self, value: object) -> None: - self.contextvar.set(value) + class ContextVarSlot(base_context.Slot): + def __init__(self, name: str, default: object): + # pylint: disable=super-init-not-called + self.name = name + self.contextvar = ContextVar(name) # type: ContextVar[object] + self.default = base_context.wrap_callable( + default + ) # type: typing.Callable[..., object] + + def clear(self) -> None: + self.contextvar.set(self.default()) + + def get(self) -> object: + try: + return self.contextvar.get() + except LookupError: + value = self.default() + self.set(value) + return value + + def set(self, value: object) -> None: + self.contextvar.set(value) + + class AsyncRuntimeContext(base_context.Context): + def with_current_context( + self, func: typing.Callable[..., "object"] + ) -> typing.Callable[..., "object"]: + """Capture the current context and apply it to the provided func. + """ + + # TODO: implement this + # ctx = contextvars.copy_context() + # ctx.run() + # caller_context = self.current() + + # def call_with_current_context( + # *args: "object", **kwargs: "object" + # ) -> "object": + # try: + # backup_context = self.current() + # self.set_current(caller_context) + # # return ctx.run(func(*args, **kwargs)) + # return func(*args, **kwargs) + # finally: + # self.set_current(backup_context) + + # return call_with_current_context diff --git a/opentelemetry-api/src/opentelemetry/context/base_context.py b/opentelemetry-api/src/opentelemetry/context/base_context.py index 99d6869dd5..e0fa514bf9 100644 --- a/opentelemetry-api/src/opentelemetry/context/base_context.py +++ b/opentelemetry-api/src/opentelemetry/context/base_context.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import threading +import abc import typing -from contextlib import contextmanager def wrap_callable(target: "object") -> typing.Callable[[], object]: @@ -23,108 +22,27 @@ def wrap_callable(target: "object") -> typing.Callable[[], object]: return lambda: target -class BaseRuntimeContext: - class Slot: - def __init__(self, name: str, default: "object"): - raise NotImplementedError +class Context: + def __init__(self) -> None: + self.snapshot = {} - def clear(self) -> None: - raise NotImplementedError + def value(self, name): + return self.snapshot.get(name) - def get(self) -> "object": - raise NotImplementedError - def set(self, value: "object") -> None: - raise NotImplementedError +class Slot(abc.ABC): + @abc.abstractmethod + def __init__(self, name: str, default: "object"): + raise NotImplementedError - _lock = threading.Lock() - _slots = {} # type: typing.Dict[str, 'BaseRuntimeContext.Slot'] + @abc.abstractmethod + def clear(self) -> None: + raise NotImplementedError - @classmethod - def clear(cls) -> None: - """Clear all slots to their default value.""" - keys = cls._slots.keys() - for name in keys: - slot = cls._slots[name] - slot.clear() + @abc.abstractmethod + def get(self) -> "object": + raise NotImplementedError - @classmethod - def register_slot( - cls, name: str, default: "object" = None - ) -> "BaseRuntimeContext.Slot": - """Register a context slot with an optional default value. - - :type name: str - :param name: The name of the context slot. - - :type default: object - :param name: The default value of the slot, can be a value or lambda. - - :returns: The registered slot. - """ - with cls._lock: - if name not in cls._slots: - cls._slots[name] = cls.Slot(name, default) - return cls._slots[name] - - def apply(self, snapshot: typing.Dict[str, "object"]) -> None: - """Set the current context from a given snapshot dictionary""" - - for name in snapshot: - setattr(self, name, snapshot[name]) - - def snapshot(self) -> typing.Dict[str, "object"]: - """Return a dictionary of current slots by reference.""" - - keys = self._slots.keys() - return dict((n, self._slots[n].get()) for n in keys) - - def __repr__(self) -> str: - return "{}({})".format(type(self).__name__, self.snapshot()) - - def __getattr__(self, name: str) -> "object": - if name not in self._slots: - self.register_slot(name, None) - slot = self._slots[name] - return slot.get() - - def __setattr__(self, name: str, value: "object") -> None: - if name not in self._slots: - self.register_slot(name, None) - slot = self._slots[name] - slot.set(value) - - def __getitem__(self, name: str) -> "object": - return self.__getattr__(name) - - def __setitem__(self, name: str, value: "object") -> None: - self.__setattr__(name, value) - - @contextmanager # type: ignore - def use(self, **kwargs: typing.Dict[str, object]) -> typing.Iterator[None]: - snapshot = {key: self[key] for key in kwargs} - for key in kwargs: - self[key] = kwargs[key] - yield - for key in kwargs: - self[key] = snapshot[key] - - def with_current_context( - self, func: typing.Callable[..., "object"] - ) -> typing.Callable[..., "object"]: - """Capture the current context and apply it to the provided func. - """ - - caller_context = self.snapshot() - - def call_with_current_context( - *args: "object", **kwargs: "object" - ) -> "object": - try: - backup_context = self.snapshot() - self.apply(caller_context) - return func(*args, **kwargs) - finally: - self.apply(backup_context) - - return call_with_current_context + @abc.abstractmethod + def set(self, value: "object") -> None: + raise NotImplementedError diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py b/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py deleted file mode 100644 index 7f1a65882f..0000000000 --- a/opentelemetry-api/src/opentelemetry/context/propagation/binaryformat.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2019, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -import typing - -from opentelemetry.trace import SpanContext - - -class BinaryFormat(abc.ABC): - """API for serialization of span context into binary formats. - - This class provides an interface that enables converting span contexts - to and from a binary format. - """ - - @staticmethod - @abc.abstractmethod - def to_bytes(context: SpanContext) -> bytes: - """Creates a byte representation of a SpanContext. - - to_bytes should read values from a SpanContext and return a data - format to represent it, in bytes. - - Args: - context: the SpanContext to serialize - - Returns: - A bytes representation of the SpanContext. - - """ - - @staticmethod - @abc.abstractmethod - def from_bytes(byte_representation: bytes) -> typing.Optional[SpanContext]: - """Return a SpanContext that was represented by bytes. - - from_bytes should return back a SpanContext that was constructed from - the data serialized in the byte_representation passed. If it is not - possible to read in a proper SpanContext, return None. - - Args: - byte_representation: the bytes to deserialize - - Returns: - A bytes representation of the SpanContext if it is valid. - Otherwise return None. - - """ diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py b/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py deleted file mode 100644 index 9b6098a9a4..0000000000 --- a/opentelemetry-api/src/opentelemetry/context/propagation/httptextformat.py +++ /dev/null @@ -1,115 +0,0 @@ -# Copyright 2019, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -import typing - -from opentelemetry.trace import SpanContext - -_T = typing.TypeVar("_T") - -Setter = typing.Callable[[_T, str, str], None] -Getter = typing.Callable[[_T, str], typing.List[str]] - - -class HTTPTextFormat(abc.ABC): - """API for propagation of span context via headers. - - This class provides an interface that enables extracting and injecting - span context into headers of HTTP requests. HTTP frameworks and clients - can integrate with HTTPTextFormat by providing the object containing the - headers, and a getter and setter function for the extraction and - injection of values, respectively. - - Example:: - - import flask - import requests - from opentelemetry.context.propagation import HTTPTextFormat - - PROPAGATOR = HTTPTextFormat() - - - - def get_header_from_flask_request(request, key): - return request.headers.get_all(key) - - def set_header_into_requests_request(request: requests.Request, - key: str, value: str): - request.headers[key] = value - - def example_route(): - span_context = PROPAGATOR.extract( - get_header_from_flask_request, - flask.request - ) - request_to_downstream = requests.Request( - "GET", "http://httpbin.org/get" - ) - PROPAGATOR.inject( - span_context, - set_header_into_requests_request, - request_to_downstream - ) - session = requests.Session() - session.send(request_to_downstream.prepare()) - - - .. _Propagation API Specification: - https://github.com/open-telemetry/opentelemetry-specification/blob/master/specification/api-propagators.md - """ - - @abc.abstractmethod - def extract( - self, get_from_carrier: Getter[_T], carrier: _T - ) -> SpanContext: - """Create a SpanContext from values in the carrier. - - The extract function should retrieve values from the carrier - object using get_from_carrier, and use values to populate a - SpanContext value and return it. - - Args: - get_from_carrier: a function that can retrieve zero - or more values from the carrier. In the case that - the value does not exist, return an empty list. - carrier: and object which contains values that are - used to construct a SpanContext. This object - must be paired with an appropriate get_from_carrier - which understands how to extract a value from it. - Returns: - A SpanContext with configuration found in the carrier. - - """ - - @abc.abstractmethod - def inject( - self, context: SpanContext, set_in_carrier: Setter[_T], carrier: _T - ) -> None: - """Inject values from a SpanContext into a carrier. - - inject enables the propagation of values into HTTP clients or - other objects which perform an HTTP request. Implementations - should use the set_in_carrier method to set values on the - carrier. - - Args: - context: The SpanContext to read values from. - set_in_carrier: A setter function that can set values - on the carrier. - carrier: An object that a place to define HTTP headers. - Should be paired with set_in_carrier, which should - know how to set header values on the carrier. - - """ diff --git a/opentelemetry-api/src/opentelemetry/context/thread_local_context.py b/opentelemetry-api/src/opentelemetry/context/thread_local_context.py index b60914f846..83f7aa4806 100644 --- a/opentelemetry-api/src/opentelemetry/context/thread_local_context.py +++ b/opentelemetry-api/src/opentelemetry/context/thread_local_context.py @@ -13,33 +13,54 @@ # limitations under the License. import threading -import typing # pylint: disable=unused-import +import typing from . import base_context -class ThreadLocalRuntimeContext(base_context.BaseRuntimeContext): - class Slot(base_context.BaseRuntimeContext.Slot): - _thread_local = threading.local() +class ThreadLocalSlot(base_context.Slot): + _thread_local = threading.local() - def __init__(self, name: str, default: "object"): - # pylint: disable=super-init-not-called - self.name = name - self.default = base_context.wrap_callable( - default - ) # type: typing.Callable[..., object] + def __init__(self, name: str, default: "object"): + # pylint: disable=super-init-not-called + self.name = name + self.default = base_context.wrap_callable( + default + ) # type: typing.Callable[..., object] - def clear(self) -> None: - setattr(self._thread_local, self.name, self.default()) + def clear(self) -> None: + setattr(self._thread_local, self.name, self.default()) - def get(self) -> "object": - try: - got = getattr(self._thread_local, self.name) # type: object - return got - except AttributeError: - value = self.default() - self.set(value) - return value + def get(self) -> "object": + try: + got = getattr(self._thread_local, self.name) # type: object + return got + except AttributeError: + value = self.default() + self.set(value) + return value - def set(self, value: "object") -> None: - setattr(self._thread_local, self.name, value) + def set(self, value: "object") -> None: + setattr(self._thread_local, self.name, value) + + +class ThreadLocalRuntimeContext(base_context.Context): + def with_current_context( + self, func: typing.Callable[..., "object"] + ) -> typing.Callable[..., "object"]: + """Capture the current context and apply it to the provided func. + """ + # TODO: implement this + # caller_context = self.current() + + # def call_with_current_context( + # *args: "object", **kwargs: "object" + # ) -> "object": + # try: + # backup_context = self.current() + # self.set_current(caller_context) + # return func(*args, **kwargs) + # finally: + # self.set_current(backup_context) + + # return call_with_current_context diff --git a/opentelemetry-api/src/opentelemetry/distributedcontext/__init__.py b/opentelemetry-api/src/opentelemetry/correlationcontext/__init__.py similarity index 56% rename from opentelemetry-api/src/opentelemetry/distributedcontext/__init__.py rename to opentelemetry-api/src/opentelemetry/correlationcontext/__init__.py index 38ef3739b9..6451dedbc2 100644 --- a/opentelemetry-api/src/opentelemetry/distributedcontext/__init__.py +++ b/opentelemetry-api/src/opentelemetry/correlationcontext/__init__.py @@ -15,7 +15,10 @@ import itertools import string import typing -from contextlib import contextmanager +from typing import Optional + +from opentelemetry import context as ctx_api +from opentelemetry.propagation import Extractor, Injector PRINTABLE = frozenset( itertools.chain( @@ -24,8 +27,9 @@ ) +# TODO: are Entry* classes still needed here? class EntryMetadata: - """A class representing metadata of a DistributedContext entry + """A class representing metadata of a CorrelationContext entry Args: entry_ttl: The time to live (in service hops) of an entry. Must be @@ -41,7 +45,7 @@ def __init__(self, entry_ttl: int) -> None: class EntryKey(str): - """A class representing a key for a DistributedContext entry""" + """A class representing a key for a CorrelationContext entry""" def __new__(cls, value: str) -> "EntryKey": return cls.create(value) @@ -56,7 +60,7 @@ def create(value: str) -> "EntryKey": class EntryValue(str): - """A class representing the value of a DistributedContext entry""" + """A class representing the value of a CorrelationContext entry""" def __new__(cls, value: str) -> "EntryValue": return cls.create(value) @@ -78,48 +82,35 @@ def __init__( self.value = value -class DistributedContext: +# TODO: is CorrelationContext still needed here? +class CorrelationContext: """A container for distributed context entries""" - def __init__(self, entries: typing.Iterable[Entry]) -> None: - self._container = {entry.key: entry for entry in entries} - - def get_entries(self) -> typing.Iterable[Entry]: - """Returns an immutable iterator to entries.""" - return self._container.values() - - def get_entry_value(self, key: EntryKey) -> typing.Optional[EntryValue]: - """Returns the entry associated with a key or None - - Args: - key: the key with which to perform a lookup - """ - if key in self._container: - return self._container[key].value - return None - - -class DistributedContextManager: - def get_current_context(self) -> typing.Optional[DistributedContext]: - """Gets the current DistributedContext. - - Returns: - A DistributedContext instance representing the current context. - """ - - @contextmanager # type: ignore - def use_context( - self, context: DistributedContext - ) -> typing.Iterator[DistributedContext]: - """Context manager for controlling a DistributedContext lifetime. - - Set the context as the active DistributedContext. - - On exiting, the context manager will restore the parent - DistributedContext. - Args: - context: A DistributedContext instance to make current. - """ - # pylint: disable=no-self-use - yield context +class CorrelationContextManager: + @classmethod + def set_correlation( + cls, + key: str, + value: "object", + context: Optional[ctx_api.Context] = None, + ) -> ctx_api.Context: + return ctx_api.set_value(key, value, context=context) + + @classmethod + def correlation( + cls, key: str, context: Optional[ctx_api.Context] = None + ) -> "object": + return ctx_api.value(key, context=context) + + @classmethod + def remove_correlation( + cls, context: Optional[ctx_api.Context] = None + ) -> ctx_api.Context: + pass + + @classmethod + def clear_correlation( + cls, context: Optional[ctx_api.Context] = None + ) -> ctx_api.Context: + pass diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/correlationcontext/propagation/__init__.py similarity index 77% rename from opentelemetry-api/src/opentelemetry/context/propagation/__init__.py rename to opentelemetry-api/src/opentelemetry/correlationcontext/propagation/__init__.py index c8706281ad..f0fde58380 100644 --- a/opentelemetry-api/src/opentelemetry/context/propagation/__init__.py +++ b/opentelemetry-api/src/opentelemetry/correlationcontext/propagation/__init__.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .binaryformat import BinaryFormat -from .httptextformat import HTTPTextFormat -__all__ = ["BinaryFormat", "HTTPTextFormat"] +class ContextKeys: + """ TODO """ + + KEY = "correlation-context" + + @classmethod + def span_context_key(cls) -> str: + """ TODO """ + return cls.KEY diff --git a/opentelemetry-api/src/opentelemetry/correlationcontext/propagation/context/__init__.py b/opentelemetry-api/src/opentelemetry/correlationcontext/propagation/context/__init__.py new file mode 100644 index 0000000000..0e4a881591 --- /dev/null +++ b/opentelemetry-api/src/opentelemetry/correlationcontext/propagation/context/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from opentelemetry.context import Context, set_value, value +from opentelemetry.correlationcontext import CorrelationContext +from opentelemetry.correlationcontext.propagation import ContextKeys + + +def correlation_context_from_context( + context: Optional[Context] = None, +) -> CorrelationContext: + return value(ContextKeys.span_context_key(), context=context) # type: ignore + + +def with_correlation_context( + correlation_context: CorrelationContext, context: Optional[Context] = None, +) -> Context: + return set_value( + ContextKeys.span_context_key(), correlation_context, context=context + ) diff --git a/opentelemetry-api/src/opentelemetry/distributedcontext/py.typed b/opentelemetry-api/src/opentelemetry/correlationcontext/py.typed similarity index 100% rename from opentelemetry-api/src/opentelemetry/distributedcontext/py.typed rename to opentelemetry-api/src/opentelemetry/correlationcontext/py.typed diff --git a/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/binaryformat.py b/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/binaryformat.py deleted file mode 100644 index d6d083c0da..0000000000 --- a/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/binaryformat.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2019, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -import typing - -from opentelemetry.distributedcontext import DistributedContext - - -class BinaryFormat(abc.ABC): - """API for serialization of span context into binary formats. - - This class provides an interface that enables converting span contexts - to and from a binary format. - """ - - @staticmethod - @abc.abstractmethod - def to_bytes(context: DistributedContext) -> bytes: - """Creates a byte representation of a DistributedContext. - - to_bytes should read values from a DistributedContext and return a data - format to represent it, in bytes. - - Args: - context: the DistributedContext to serialize - - Returns: - A bytes representation of the DistributedContext. - - """ - - @staticmethod - @abc.abstractmethod - def from_bytes( - byte_representation: bytes, - ) -> typing.Optional[DistributedContext]: - """Return a DistributedContext that was represented by bytes. - - from_bytes should return back a DistributedContext that was constructed - from the data serialized in the byte_representation passed. If it is - not possible to read in a proper DistributedContext, return None. - - Args: - byte_representation: the bytes to deserialize - - Returns: - A bytes representation of the DistributedContext if it is valid. - Otherwise return None. - - """ diff --git a/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/httptextformat.py b/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/httptextformat.py deleted file mode 100644 index 3e2c186283..0000000000 --- a/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/httptextformat.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2019, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -import typing - -from opentelemetry.distributedcontext import DistributedContext - -Setter = typing.Callable[[object, str, str], None] -Getter = typing.Callable[[object, str], typing.List[str]] - - -class HTTPTextFormat(abc.ABC): - """API for propagation of span context via headers. - - This class provides an interface that enables extracting and injecting - span context into headers of HTTP requests. HTTP frameworks and clients - can integrate with HTTPTextFormat by providing the object containing the - headers, and a getter and setter function for the extraction and - injection of values, respectively. - - Example:: - - import flask - import requests - from opentelemetry.context.propagation import HTTPTextFormat - - PROPAGATOR = HTTPTextFormat() - - def get_header_from_flask_request(request, key): - return request.headers.get_all(key) - - def set_header_into_requests_request(request: requests.Request, - key: str, value: str): - request.headers[key] = value - - def example_route(): - distributed_context = PROPAGATOR.extract( - get_header_from_flask_request, - flask.request - ) - request_to_downstream = requests.Request( - "GET", "http://httpbin.org/get" - ) - PROPAGATOR.inject( - distributed_context, - set_header_into_requests_request, - request_to_downstream - ) - session = requests.Session() - session.send(request_to_downstream.prepare()) - - - .. _Propagation API Specification: - https://github.com/open-telemetry/opentelemetry-specification/blob/master/specification/api-propagators.md - """ - - @abc.abstractmethod - def extract( - self, get_from_carrier: Getter, carrier: object - ) -> DistributedContext: - """Create a DistributedContext from values in the carrier. - - The extract function should retrieve values from the carrier - object using get_from_carrier, and use values to populate a - DistributedContext value and return it. - - Args: - get_from_carrier: a function that can retrieve zero - or more values from the carrier. In the case that - the value does not exist, return an empty list. - carrier: and object which contains values that are - used to construct a DistributedContext. This object - must be paired with an appropriate get_from_carrier - which understands how to extract a value from it. - Returns: - A DistributedContext with configuration found in the carrier. - - """ - - @abc.abstractmethod - def inject( - self, - context: DistributedContext, - set_in_carrier: Setter, - carrier: object, - ) -> None: - """Inject values from a DistributedContext into a carrier. - - inject enables the propagation of values into HTTP clients or - other objects which perform an HTTP request. Implementations - should use the set_in_carrier method to set values on the - carrier. - - Args: - context: The DistributedContext to read values from. - set_in_carrier: A setter function that can set values - on the carrier. - carrier: An object that a place to define HTTP headers. - Should be paired with set_in_carrier, which should - know how to set header values on the carrier. - - """ diff --git a/opentelemetry-api/src/opentelemetry/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/propagation/__init__.py new file mode 100644 index 0000000000..fb8b5e39c3 --- /dev/null +++ b/opentelemetry-api/src/opentelemetry/propagation/__init__.py @@ -0,0 +1,284 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +The OpenTelemetry propagation module provides an interface that enables +propagation of Context details across any concerns which implement the +Extractor and Injector interfaces. + +Example:: + + import flask + import requests + from opentelemetry.propagation import DefaultExtractor, DefaultInjector + + extractor = DefaultExtractor() + injector = DefaultInjector() + + + def get_header_from_flask_request(request, key): + return request.headers.get_all(key) + + def set_header_into_requests_request(request: requests.Request, + key: str, value: str): + request.headers[key] = value + + def example_route(): + span_context = extractor.extract( + get_header_from_flask_request, + flask.request + ) + request_to_downstream = requests.Request( + "GET", "http://httpbin.org/get" + ) + injector.inject( + span_context, + set_header_into_requests_request, + request_to_downstream + ) + session = requests.Session() + session.send(request_to_downstream.prepare()) + + +.. _Propagation API Specification: + https://github.com/open-telemetry/opentelemetry-specification/blob/master/specification/api-propagators.md +""" +import abc +import typing + +import opentelemetry.trace as trace +from opentelemetry.context import Context, current + +ContextT = typing.TypeVar("ContextT") + +Setter = typing.Callable[[ContextT, str, str], None] +Getter = typing.Callable[[ContextT, str], typing.List[str]] + + +def get_as_list( + dict_object: typing.Dict[str, str], key: str +) -> typing.List[str]: + value = dict_object.get(key) + if value is None: + return [] + if isinstance(value, list): + return value + return [value] + + +def set_in_dict( + dict_object: typing.Dict[str, str], key: str, value: str +) -> None: + dict_object[key] = value + + +class Extractor(abc.ABC): + @classmethod + @abc.abstractmethod + def extract( + cls, + carrier: ContextT, + context: typing.Optional[Context] = None, + get_from_carrier: typing.Optional[Getter[ContextT]] = get_as_list, + ) -> Context: + """Create a Context from values in the carrier. + + The extract function should retrieve values from the carrier + object using get_from_carrier, use values to populate a + Context value and return it. + + Args: + carrier: And object which contains values that are + used to construct a Context. This object + must be paired with an appropriate get_from_carrier + which understands how to extract a value from it. + context: The Context to set values into. + get_from_carrier: A function that can retrieve zero + or more values from the carrier. In the case that + the value does not exist, return an empty list. + Returns: + A Context with configuration found in the carrier. + """ + + +class Injector(abc.ABC): + @classmethod + @abc.abstractmethod + def inject( + cls, + carrier: ContextT, + context: typing.Optional[Context] = None, + set_in_carrier: typing.Optional[Setter[ContextT]] = set_in_dict, + ) -> None: + """Inject values from a Context into a carrier. + + inject enables the propagation of values into HTTP clients or + other objects which perform an HTTP request. Implementations + should use the set_in_carrier method to set values on the + carrier. + + Args: + carrier: An object that a place to define HTTP headers. + Should be paired with set_in_carrier, which should + know how to set header values on the carrier. + context: The Context to read values from. + set_in_carrier: A setter function that can set values + on the carrier. + """ + + +class DefaultExtractor(Extractor): + """The default Extractor that is used when no Extractor implementation is configured. + + All operations are no-ops. + """ + + @classmethod + def extract( + cls, + carrier: ContextT, + context: typing.Optional[Context] = None, + get_from_carrier: typing.Optional[Getter[ContextT]] = get_as_list, + ) -> Context: + if context: + return context + return current() + + +class DefaultInjector(Injector): + """The default Injector that is used when no Injector implementation is configured. + + All operations are no-ops. + """ + + @classmethod + def inject( + cls, + carrier: ContextT, + context: typing.Optional[Context] = None, + set_in_carrier: typing.Optional[Setter[ContextT]] = set_in_dict, + ) -> None: + return None + + +def extract( + carrier: ContextT, + context: typing.Optional[Context] = None, + extractors: typing.Optional[typing.List[Extractor]] = None, + get_from_carrier: typing.Optional[Getter[ContextT]] = get_as_list, +) -> typing.Optional[Context]: + """Load the Context from values in the carrier. + + Using the specified Extractor, the propagator will + extract a Context from the carrier. + + Args: + get_from_carrier: A function that can retrieve zero + or more values from the carrier. In the case that + the value does not exist, return an empty list. + carrier: An object which contains values that are + used to construct a SpanContext. This object + must be paired with an appropriate get_from_carrier + which understands how to extract a value from it. + """ + if context is None: + context = current() + if extractors is None: + extractors = get_http_extractors() + + for extractor in extractors: + # TODO: improve this + if get_from_carrier: + return extractor.extract( + context=context, + carrier=carrier, + get_from_carrier=get_from_carrier, + ) + return extractor.extract(context=context, carrier=carrier) + + return None + + +def inject( + carrier: ContextT, + injectors: typing.Optional[typing.List[Injector]] = None, + context: typing.Optional[Context] = None, + set_in_carrier: typing.Optional[Setter[ContextT]] = set_in_dict, +) -> None: + """Inject values from the current context into the carrier. + + inject enables the propagation of values into HTTP clients or + other objects which perform an HTTP request. Implementations + should use the set_in_carrier method to set values on the + carrier. + + Args: + set_in_carrier: A setter function that can set values + on the carrier. + carrier: An object that contains a representation of HTTP + headers. Should be paired with set_in_carrier, which + should know how to set header values on the carrier. + """ + if context is None: + context = current() + if injectors is None: + injectors = get_http_injectors() + + for injector in injectors: + injector.inject( + context=context, carrier=carrier, set_in_carrier=set_in_carrier + ) + + +_HTTP_TEXT_INJECTORS = [ + DefaultInjector +] # typing.List[httptextformat.Injector] + +_HTTP_TEXT_EXTRACTORS = [ + DefaultExtractor +] # typing.List[httptextformat.Extractor] + + +def set_http_extractors(extractor_list: typing.List[Extractor],) -> None: + """ + To update the global extractor, the Propagation API provides a + function which takes an extractor. + """ + global _HTTP_TEXT_EXTRACTORS # pylint:disable=global-statement + _HTTP_TEXT_EXTRACTORS = extractor_list # type: ignore + + +def set_http_injectors(injector_list: typing.List[Injector],) -> None: + """ + To update the global injector, the Propagation API provides a + function which takes an injector. + """ + global _HTTP_TEXT_INJECTORS # pylint:disable=global-statement + _HTTP_TEXT_INJECTORS = injector_list # type: ignore + + +def get_http_extractors() -> typing.List[Extractor]: + """ + To access the global extractor, the Propagation API provides + a function which returns an extractor. + """ + return _HTTP_TEXT_EXTRACTORS # type: ignore + + +def get_http_injectors() -> typing.List[Injector]: + """ + To access the global injector, the Propagation API provides a + function which returns an injector. + """ + return _HTTP_TEXT_INJECTORS # type: ignore diff --git a/opentelemetry-api/src/opentelemetry/propagators/__init__.py b/opentelemetry-api/src/opentelemetry/propagators/__init__.py deleted file mode 100644 index bb75d84c3a..0000000000 --- a/opentelemetry-api/src/opentelemetry/propagators/__init__.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright 2019, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import typing - -import opentelemetry.context.propagation.httptextformat as httptextformat -import opentelemetry.trace as trace -from opentelemetry.context.propagation.tracecontexthttptextformat import ( - TraceContextHTTPTextFormat, -) - -_T = typing.TypeVar("_T") - - -def extract( - get_from_carrier: httptextformat.Getter[_T], carrier: _T -) -> trace.SpanContext: - """Load the parent SpanContext from values in the carrier. - - Using the specified HTTPTextFormatter, the propagator will - extract a SpanContext from the carrier. If one is found, - it will be set as the parent context of the current span. - - Args: - get_from_carrier: a function that can retrieve zero - or more values from the carrier. In the case that - the value does not exist, return an empty list. - carrier: and object which contains values that are - used to construct a SpanContext. This object - must be paired with an appropriate get_from_carrier - which understands how to extract a value from it. - """ - return get_global_httptextformat().extract(get_from_carrier, carrier) - - -def inject( - tracer: trace.Tracer, - set_in_carrier: httptextformat.Setter[_T], - carrier: _T, -) -> None: - """Inject values from the current context into the carrier. - - inject enables the propagation of values into HTTP clients or - other objects which perform an HTTP request. Implementations - should use the set_in_carrier method to set values on the - carrier. - - Args: - set_in_carrier: A setter function that can set values - on the carrier. - carrier: An object that contains a representation of HTTP - headers. Should be paired with set_in_carrier, which - should know how to set header values on the carrier. - """ - get_global_httptextformat().inject( - tracer.get_current_span().get_context(), set_in_carrier, carrier - ) - - -_HTTP_TEXT_FORMAT = ( - TraceContextHTTPTextFormat() -) # type: httptextformat.HTTPTextFormat - - -def get_global_httptextformat() -> httptextformat.HTTPTextFormat: - return _HTTP_TEXT_FORMAT - - -def set_global_httptextformat( - http_text_format: httptextformat.HTTPTextFormat, -) -> None: - global _HTTP_TEXT_FORMAT # pylint:disable=global-statement - _HTTP_TEXT_FORMAT = http_text_format diff --git a/opentelemetry-api/src/opentelemetry/trace/__init__.py b/opentelemetry-api/src/opentelemetry/trace/__init__.py index e426d11a1a..af5c82e0b1 100644 --- a/opentelemetry-api/src/opentelemetry/trace/__init__.py +++ b/opentelemetry-api/src/opentelemetry/trace/__init__.py @@ -71,6 +71,7 @@ import typing from contextlib import contextmanager +from opentelemetry.context import Context from opentelemetry.trace.status import Status from opentelemetry.util import loader, types @@ -416,7 +417,9 @@ class Tracer: # This is the default behavior when creating spans. CURRENT_SPAN = Span() - def get_current_span(self) -> "Span": + def get_current_span( + self, context: typing.Optional[Context] = None + ) -> "Span": """Gets the currently active span from the context. If there is no current span, return a placeholder span with an invalid @@ -426,7 +429,7 @@ def get_current_span(self) -> "Span": The currently active :class:`.Span`, or a placeholder span with an invalid :class:`.SpanContext`. """ - # pylint: disable=no-self-use + # pylint: disable=unused-argument,no-self-use return INVALID_SPAN def start_span( @@ -438,6 +441,7 @@ def start_span( links: typing.Sequence[Link] = (), start_time: typing.Optional[int] = None, set_status_on_exception: bool = True, + context: typing.Optional[Context] = None, ) -> "Span": """Starts a span. @@ -489,6 +493,7 @@ def start_as_current_span( kind: SpanKind = SpanKind.INTERNAL, attributes: typing.Optional[types.Attributes] = None, links: typing.Sequence[Link] = (), + context: typing.Optional[Context] = None, ) -> typing.Iterator["Span"]: """Context manager for creating a new span and set it as the current span in this tracer's context. diff --git a/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/trace/propagation/__init__.py similarity index 60% rename from opentelemetry-api/src/opentelemetry/distributedcontext/propagation/__init__.py rename to opentelemetry-api/src/opentelemetry/trace/propagation/__init__.py index c8706281ad..67a74dbac5 100644 --- a/opentelemetry-api/src/opentelemetry/distributedcontext/propagation/__init__.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/__init__.py @@ -12,7 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .binaryformat import BinaryFormat -from .httptextformat import HTTPTextFormat -__all__ = ["BinaryFormat", "HTTPTextFormat"] +class ContextKeys: + """ TODO """ + + EXTRACT_SPAN_CONTEXT_KEY = "extracted-span-context" + SPAN_KEY = "current-span" + + @classmethod + def span_context_key(cls) -> str: + """ Returns key for a SpanContext """ + return cls.EXTRACT_SPAN_CONTEXT_KEY + + @classmethod + def span_key(cls) -> str: + """ Returns key for a Span """ + return cls.SPAN_KEY diff --git a/opentelemetry-api/src/opentelemetry/trace/propagation/context/__init__.py b/opentelemetry-api/src/opentelemetry/trace/propagation/context/__init__.py new file mode 100644 index 0000000000..961f3fd9b5 --- /dev/null +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/context/__init__.py @@ -0,0 +1,49 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Optional + +from opentelemetry import context as ctx_api +from opentelemetry.trace import INVALID_SPAN_CONTEXT, Span, SpanContext +from opentelemetry.trace.propagation import ContextKeys + + +def span_context_from_context( + context: Optional[ctx_api.Context] = None, +) -> SpanContext: + span = span_from_context(context=context) + if span: + return span.get_context() + sc = ctx_api.value(ContextKeys.span_context_key(), context=context) # type: ignore + if sc: + return sc + + return INVALID_SPAN_CONTEXT + + +def with_span_context( + span_context: SpanContext, context: Optional[ctx_api.Context] = None +) -> ctx_api.Context: + return ctx_api.set_value( + ContextKeys.span_context_key(), span_context, context=context + ) + + +def span_from_context(context: Optional[ctx_api.Context] = None) -> Span: + return ctx_api.value(ContextKeys.span_key(), context=context) # type: ignore + + +def with_span( + span: Span, context: Optional[ctx_api.Context] = None +) -> ctx_api.Context: + return ctx_api.set_value(ContextKeys.span_key(), span, context=context) diff --git a/opentelemetry-api/tests/context/__init__.py b/opentelemetry-api/tests/context/__init__.py index e69de29bb2..d853a7bcf6 100644 --- a/opentelemetry-api/tests/context/__init__.py +++ b/opentelemetry-api/tests/context/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/opentelemetry-api/tests/context/test_context.py b/opentelemetry-api/tests/context/test_context.py new file mode 100644 index 0000000000..77ca794345 --- /dev/null +++ b/opentelemetry-api/tests/context/test_context.py @@ -0,0 +1,87 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from opentelemetry import context + + +def do_work(): + context.set_value("say-something", "bar") + + +class TestContext(unittest.TestCase): + def test_context(self): + self.assertIsNone(context.current().value("say-something")) + empty_context = context.current() + context.set_value("say-something", "foo") + self.assertEqual(context.current().value("say-something"), "foo") + second_context = context.current() + + do_work() + self.assertEqual(context.current().value("say-something"), "bar") + third_context = context.current() + + self.assertIsNone(empty_context.value("say-something")) + self.assertEqual(second_context.value("say-something"), "foo") + self.assertEqual(third_context.value("say-something"), "bar") + + def test_merge(self): + context.set_value("name", "first") + context.set_value("somebool", True) + context.set_value("key", "value") + context.set_value("otherkey", "othervalue") + src_ctx = context.current() + + context.set_value("name", "second") + context.set_value("somebool", False) + context.set_value("anotherkey", "anothervalue") + dst_ctx = context.current() + + context.set_current( + context.merge_context_correlation(src_ctx, dst_ctx) + ) + current = context.current() + self.assertEqual(current.value("name"), "first") + self.assertTrue(current.value("somebool")) + self.assertEqual(current.value("key"), "value") + self.assertEqual(current.value("otherkey"), "othervalue") + self.assertEqual(current.value("anotherkey"), "anothervalue") + + def test_propagation(self): + pass + + def test_restore_context_on_exit(self): + context.set_current(context.new_context()) + context.set_value("a", "xxx") + context.set_value("b", "yyy") + + self.assertEqual({"a": "xxx", "b": "yyy"}, context.current().snapshot) + with context.use(a="foo"): + self.assertEqual( + {"a": "foo", "b": "yyy"}, context.current().snapshot + ) + context.set_value("a", "i_want_to_mess_it_but_wont_work") + context.set_value("b", "i_want_to_mess_it") + self.assertEqual({"a": "xxx", "b": "yyy"}, context.current().snapshot) + + def test_set_value(self): + first = context.set_value("a", "yyy") + second = context.set_value("a", "zzz") + third = context.set_value("a", "---", first) + current_context = context.current() + self.assertEqual("yyy", context.value("a", context=first)) + self.assertEqual("zzz", context.value("a", context=second)) + self.assertEqual("---", context.value("a", context=third)) + self.assertEqual("zzz", context.value("a", context=current_context)) diff --git a/opentelemetry-api/tests/distributedcontext/__init__.py b/opentelemetry-api/tests/correlationcontext/__init__.py similarity index 100% rename from opentelemetry-api/tests/distributedcontext/__init__.py rename to opentelemetry-api/tests/correlationcontext/__init__.py diff --git a/opentelemetry-api/tests/correlationcontext/test_distributed_context.py b/opentelemetry-api/tests/correlationcontext/test_distributed_context.py new file mode 100644 index 0000000000..0ea05e2c9b --- /dev/null +++ b/opentelemetry-api/tests/correlationcontext/test_distributed_context.py @@ -0,0 +1,122 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from opentelemetry import correlationcontext + + +class TestEntryMetadata(unittest.TestCase): + def test_entry_ttl_no_propagation(self): + metadata = correlationcontext.EntryMetadata( + correlationcontext.EntryMetadata.NO_PROPAGATION + ) + self.assertEqual(metadata.entry_ttl, 0) + + def test_entry_ttl_unlimited_propagation(self): + metadata = correlationcontext.EntryMetadata( + correlationcontext.EntryMetadata.UNLIMITED_PROPAGATION + ) + self.assertEqual(metadata.entry_ttl, -1) + + +class TestEntryKey(unittest.TestCase): + def test_create_empty(self): + with self.assertRaises(ValueError): + correlationcontext.EntryKey.create("") + + def test_create_too_long(self): + with self.assertRaises(ValueError): + correlationcontext.EntryKey.create("a" * 256) + + def test_create_invalid_character(self): + with self.assertRaises(ValueError): + correlationcontext.EntryKey.create("\x00") + + def test_create_valid(self): + key = correlationcontext.EntryKey.create("ok") + self.assertEqual(key, "ok") + + def test_key_new(self): + key = correlationcontext.EntryKey("ok") + self.assertEqual(key, "ok") + + +class TestEntryValue(unittest.TestCase): + def test_create_invalid_character(self): + with self.assertRaises(ValueError): + correlationcontext.EntryValue.create("\x00") + + def test_create_valid(self): + key = correlationcontext.EntryValue.create("ok") + self.assertEqual(key, "ok") + + def test_key_new(self): + key = correlationcontext.EntryValue("ok") + self.assertEqual(key, "ok") + + +# TODO:replace these +# class TestCorrelationContext(unittest.TestCase): +# def setUp(self): +# self.entry = correlationcontext.Entry( +# correlationcontext.EntryMetadata( +# correlationcontext.EntryMetadata.NO_PROPAGATION +# ), +# correlationcontext.EntryKey("key"), +# correlationcontext.EntryValue("value"), +# ) +# self.context = with_correlation_context( +# CorrelationContext(entries=[self.entry]) +# ) + +# def test_get_entries(self): +# self.assertIn( +# self.entry, correlation_context_from_context(self.context).get_entries(), +# ) + +# def test_get_entry_value_present(self): +# value = correlationcontext.CorrelationContext.get_entry_value( +# self.context, self.entry.key +# ) +# self.assertIs(value, self.entry.value) + +# def test_get_entry_value_missing(self): +# key = correlationcontext.EntryKey("missing") +# value = correlationcontext.CorrelationContext.get_entry_value( +# self.context, key +# ) +# self.assertIsNone(value) + + +# TODO:replace these +# class TestCorrelationContextManager(unittest.TestCase): +# def setUp(self): +# self.manager = correlationcontext.CorrelationContextManager() + +# def test_current_context(self): +# self.assertIsNone(self.manager.current_context()) + +# def test_use_context(self): +# expected = correlationcontext.CorrelationContext( +# ( +# correlationcontext.Entry( +# correlationcontext.EntryMetadata(0), +# correlationcontext.EntryKey("0"), +# correlationcontext.EntryValue(""), +# ), +# ) +# ) +# with self.manager.use_context(expected) as output: +# self.assertIs(output, expected) diff --git a/opentelemetry-api/tests/distributedcontext/test_distributed_context.py b/opentelemetry-api/tests/distributedcontext/test_distributed_context.py deleted file mode 100644 index c730603b16..0000000000 --- a/opentelemetry-api/tests/distributedcontext/test_distributed_context.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2019, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from opentelemetry import distributedcontext - - -class TestEntryMetadata(unittest.TestCase): - def test_entry_ttl_no_propagation(self): - metadata = distributedcontext.EntryMetadata( - distributedcontext.EntryMetadata.NO_PROPAGATION - ) - self.assertEqual(metadata.entry_ttl, 0) - - def test_entry_ttl_unlimited_propagation(self): - metadata = distributedcontext.EntryMetadata( - distributedcontext.EntryMetadata.UNLIMITED_PROPAGATION - ) - self.assertEqual(metadata.entry_ttl, -1) - - -class TestEntryKey(unittest.TestCase): - def test_create_empty(self): - with self.assertRaises(ValueError): - distributedcontext.EntryKey.create("") - - def test_create_too_long(self): - with self.assertRaises(ValueError): - distributedcontext.EntryKey.create("a" * 256) - - def test_create_invalid_character(self): - with self.assertRaises(ValueError): - distributedcontext.EntryKey.create("\x00") - - def test_create_valid(self): - key = distributedcontext.EntryKey.create("ok") - self.assertEqual(key, "ok") - - def test_key_new(self): - key = distributedcontext.EntryKey("ok") - self.assertEqual(key, "ok") - - -class TestEntryValue(unittest.TestCase): - def test_create_invalid_character(self): - with self.assertRaises(ValueError): - distributedcontext.EntryValue.create("\x00") - - def test_create_valid(self): - key = distributedcontext.EntryValue.create("ok") - self.assertEqual(key, "ok") - - def test_key_new(self): - key = distributedcontext.EntryValue("ok") - self.assertEqual(key, "ok") - - -class TestDistributedContext(unittest.TestCase): - def setUp(self): - entry = self.entry = distributedcontext.Entry( - distributedcontext.EntryMetadata( - distributedcontext.EntryMetadata.NO_PROPAGATION - ), - distributedcontext.EntryKey("key"), - distributedcontext.EntryValue("value"), - ) - self.context = distributedcontext.DistributedContext((entry,)) - - def test_get_entries(self): - self.assertIn(self.entry, self.context.get_entries()) - - def test_get_entry_value_present(self): - value = self.context.get_entry_value(self.entry.key) - self.assertIs(value, self.entry.value) - - def test_get_entry_value_missing(self): - key = distributedcontext.EntryKey("missing") - value = self.context.get_entry_value(key) - self.assertIsNone(value) - - -class TestDistributedContextManager(unittest.TestCase): - def setUp(self): - self.manager = distributedcontext.DistributedContextManager() - - def test_get_current_context(self): - self.assertIsNone(self.manager.get_current_context()) - - def test_use_context(self): - expected = distributedcontext.DistributedContext( - ( - distributedcontext.Entry( - distributedcontext.EntryMetadata(0), - distributedcontext.EntryKey("0"), - distributedcontext.EntryValue(""), - ), - ) - ) - with self.manager.use_context(expected) as output: - self.assertIs(output, expected) diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py index 7d59fddb9e..3ad69a1d29 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/b3_format.py @@ -15,24 +15,50 @@ import typing import opentelemetry.trace as trace -from opentelemetry.context.propagation.httptextformat import HTTPTextFormat +from opentelemetry.context import Context +from opentelemetry.propagation import ( + Extractor, + Getter, + Injector, + Setter, + get_as_list, + set_in_dict, +) +from opentelemetry.trace.propagation.context import ( + span_context_from_context, + with_span_context, +) +_T = typing.TypeVar("_T") + +TRACE_ID_KEY = "x-b3-traceid" +SPAN_ID_KEY = "x-b3-spanid" +SAMPLED_KEY = "x-b3-sampled" + + +def http_propagator() -> typing.Tuple[Extractor, Injector]: + """ TODO """ + return B3Extractor, B3Injector -class B3Format(HTTPTextFormat): + +class B3Extractor(Extractor): """Propagator for the B3 HTTP header format. See: https://github.com/openzipkin/b3-propagation """ SINGLE_HEADER_KEY = "b3" - TRACE_ID_KEY = "x-b3-traceid" - SPAN_ID_KEY = "x-b3-spanid" - SAMPLED_KEY = "x-b3-sampled" FLAGS_KEY = "x-b3-flags" _SAMPLE_PROPAGATE_VALUES = set(["1", "True", "true", "d"]) @classmethod - def extract(cls, get_from_carrier, carrier): + def extract( + cls, + carrier, + context: typing.Optional[Context] = None, + get_from_carrier: typing.Optional[Getter[_T]] = get_as_list, + ): + trace_id = format_trace_id(trace.INVALID_TRACE_ID) span_id = format_span_id(trace.INVALID_SPAN_ID) sampled = "0" @@ -57,24 +83,18 @@ def extract(cls, get_from_carrier, carrier): elif len(fields) == 4: trace_id, span_id, sampled, _parent_span_id = fields else: - return trace.INVALID_SPAN_CONTEXT + return with_span_context(trace.INVALID_SPAN_CONTEXT) else: trace_id = ( - _extract_first_element( - get_from_carrier(carrier, cls.TRACE_ID_KEY) - ) + _extract_first_element(get_from_carrier(carrier, TRACE_ID_KEY)) or trace_id ) span_id = ( - _extract_first_element( - get_from_carrier(carrier, cls.SPAN_ID_KEY) - ) + _extract_first_element(get_from_carrier(carrier, SPAN_ID_KEY)) or span_id ) sampled = ( - _extract_first_element( - get_from_carrier(carrier, cls.SAMPLED_KEY) - ) + _extract_first_element(get_from_carrier(carrier, SAMPLED_KEY)) or sampled ) flags = ( @@ -91,24 +111,31 @@ def extract(cls, get_from_carrier, carrier): # header is set to allow. if sampled in cls._SAMPLE_PROPAGATE_VALUES or flags == "1": options |= trace.TraceOptions.SAMPLED - return trace.SpanContext( - # trace an span ids are encoded in hex, so must be converted - trace_id=int(trace_id, 16), - span_id=int(span_id, 16), - trace_options=trace.TraceOptions(options), - trace_state=trace.TraceState(), + + return with_span_context( + trace.SpanContext( + # trace an span ids are encoded in hex, so must be converted + trace_id=int(trace_id, 16), + span_id=int(span_id, 16), + trace_options=trace.TraceOptions(options), + trace_state=trace.TraceState(), + ), ) + +class B3Injector(Injector): @classmethod - def inject(cls, context, set_in_carrier, carrier): - sampled = (trace.TraceOptions.SAMPLED & context.trace_options) != 0 - set_in_carrier( - carrier, cls.TRACE_ID_KEY, format_trace_id(context.trace_id) - ) - set_in_carrier( - carrier, cls.SPAN_ID_KEY, format_span_id(context.span_id) - ) - set_in_carrier(carrier, cls.SAMPLED_KEY, "1" if sampled else "0") + def inject( + cls, + carrier, + context: typing.Optional[Context] = None, + set_in_carrier: typing.Optional[Setter[_T]] = set_in_dict, + ): + sc = span_context_from_context(context) + sampled = (trace.TraceOptions.SAMPLED & sc.trace_options) != 0 + set_in_carrier(carrier, TRACE_ID_KEY, format_trace_id(sc.trace_id)) + set_in_carrier(carrier, SPAN_ID_KEY, format_span_id(sc.span_id)) + set_in_carrier(carrier, SAMPLED_KEY, "1" if sampled else "0") def format_trace_id(trace_id: int) -> str: diff --git a/opentelemetry-api/src/opentelemetry/context/propagation/tracecontexthttptextformat.py b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/tracecontexthttptextformat.py similarity index 66% rename from opentelemetry-api/src/opentelemetry/context/propagation/tracecontexthttptextformat.py rename to opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/tracecontexthttptextformat.py index 5d00632ed1..11ae57621f 100644 --- a/opentelemetry-api/src/opentelemetry/context/propagation/tracecontexthttptextformat.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/context/propagation/tracecontexthttptextformat.py @@ -16,7 +16,19 @@ import typing import opentelemetry.trace as trace -from opentelemetry.context.propagation import httptextformat +from opentelemetry.context import Context +from opentelemetry.propagation import ( + Extractor, + Getter, + Injector, + Setter, + get_as_list, + set_in_dict, +) +from opentelemetry.trace.propagation.context import ( + span_context_from_context, + with_span_context, +) _T = typing.TypeVar("_T") @@ -47,12 +59,19 @@ _TRACECONTEXT_MAXIMUM_TRACESTATE_KEYS = 32 -class TraceContextHTTPTextFormat(httptextformat.HTTPTextFormat): - """Extracts and injects using w3c TraceContext's headers. +TRACEPARENT_HEADER_NAME = "traceparent" +TRACESTATE_HEADER_NAME = "tracestate" + + +def http_propagator() -> typing.Tuple[Extractor, Injector]: + """ TODO """ + return TraceContextHTTPExtractor, TraceContextHTTPInjector + + +class TraceContextHTTPExtractor(Extractor): + """Extracts using w3c TraceContext's headers. """ - _TRACEPARENT_HEADER_NAME = "traceparent" - _TRACESTATE_HEADER_NAME = "tracestate" _TRACEPARENT_HEADER_FORMAT = ( "^[ \t]*([0-9a-f]{2})-([0-9a-f]{32})-([0-9a-f]{16})-([0-9a-f]{2})" + "(-.*)?[ \t]*$" @@ -61,18 +80,21 @@ class TraceContextHTTPTextFormat(httptextformat.HTTPTextFormat): @classmethod def extract( - cls, get_from_carrier: httptextformat.Getter[_T], carrier: _T - ) -> trace.SpanContext: + cls, + carrier: _T, + context: typing.Optional[Context] = None, + get_from_carrier: typing.Optional[Getter[_T]] = get_as_list, + ) -> Context: """Extracts a valid SpanContext from the carrier. """ - header = get_from_carrier(carrier, cls._TRACEPARENT_HEADER_NAME) + header = get_from_carrier(carrier, TRACEPARENT_HEADER_NAME) if not header: - return trace.INVALID_SPAN_CONTEXT + return with_span_context(trace.INVALID_SPAN_CONTEXT) match = re.search(cls._TRACEPARENT_HEADER_FORMAT_RE, header[0]) if not match: - return trace.INVALID_SPAN_CONTEXT + return with_span_context(trace.INVALID_SPAN_CONTEXT) version = match.group(1) trace_id = match.group(2) @@ -80,48 +102,55 @@ def extract( trace_options = match.group(4) if trace_id == "0" * 32 or span_id == "0" * 16: - return trace.INVALID_SPAN_CONTEXT + return with_span_context(trace.INVALID_SPAN_CONTEXT) if version == "00": if match.group(5): - return trace.INVALID_SPAN_CONTEXT + return with_span_context(trace.INVALID_SPAN_CONTEXT) if version == "ff": - return trace.INVALID_SPAN_CONTEXT + return with_span_context(trace.INVALID_SPAN_CONTEXT) - tracestate_headers = get_from_carrier( - carrier, cls._TRACESTATE_HEADER_NAME - ) + tracestate_headers = get_from_carrier(carrier, TRACESTATE_HEADER_NAME) tracestate = _parse_tracestate(tracestate_headers) - span_context = trace.SpanContext( - trace_id=int(trace_id, 16), - span_id=int(span_id, 16), - trace_options=trace.TraceOptions(trace_options), - trace_state=tracestate, + return with_span_context( + trace.SpanContext( + trace_id=int(trace_id, 16), + span_id=int(span_id, 16), + trace_options=trace.TraceOptions(trace_options), + trace_state=tracestate, + ) ) - return span_context + +class TraceContextHTTPInjector(Injector): + """Injects using w3c TraceContext's headers. + """ @classmethod def inject( cls, - context: trace.SpanContext, - set_in_carrier: httptextformat.Setter[_T], carrier: _T, + context: typing.Optional[Context] = None, + set_in_carrier: typing.Optional[Setter[_T]] = set_in_dict, ) -> None: - if context == trace.INVALID_SPAN_CONTEXT: + sc = span_context_from_context(context) + if sc is None or sc == trace.INVALID_SPAN_CONTEXT: + return + + if ( + sc.trace_id == trace.INVALID_TRACE_ID + or sc.span_id == trace.INVALID_SPAN_ID + ): return + traceparent_string = "00-{:032x}-{:016x}-{:02x}".format( - context.trace_id, context.span_id, context.trace_options + sc.trace_id, sc.span_id, sc.trace_options, ) - set_in_carrier( - carrier, cls._TRACEPARENT_HEADER_NAME, traceparent_string - ) - if context.trace_state: - tracestate_string = _format_tracestate(context.trace_state) - set_in_carrier( - carrier, cls._TRACESTATE_HEADER_NAME, tracestate_string - ) + set_in_carrier(carrier, TRACEPARENT_HEADER_NAME, traceparent_string) + if sc.trace_state: + tracestate_string = _format_tracestate(sc.trace_state) + set_in_carrier(carrier, TRACESTATE_HEADER_NAME, tracestate_string) def _parse_tracestate(header_list: typing.List[str]) -> trace.TraceState: diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/correlationcontext/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/correlationcontext/__init__.py new file mode 100644 index 0000000000..06473fa4b5 --- /dev/null +++ b/opentelemetry-sdk/src/opentelemetry/sdk/correlationcontext/__init__.py @@ -0,0 +1,62 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing +from contextlib import contextmanager + +from opentelemetry import correlationcontext as cctx_api +from opentelemetry.context import current, set_value, value + + +class CorrelationContextManager(cctx_api.CorrelationContextManager): + """See `opentelemetry.correlationcontext.CorrelationContextManager` + + Args: + name: The name of the context manager + """ + + def __init__(self, name: str = "") -> None: + if name: + self.slot_name = "CorrelationContext.{}".format(name) + else: + self.slot_name = "CorrelationContext" + + def current_context(self,) -> typing.Optional[cctx_api.CorrelationContext]: + """Gets the current CorrelationContext. + + Returns: + A CorrelationContext instance representing the current context. + """ + return value(self.slot_name) + + @contextmanager + def use_context( + self, context: cctx_api.CorrelationContext + ) -> typing.Iterator[cctx_api.CorrelationContext]: + """Context manager for controlling a CorrelationContext lifetime. + + Set the context as the active CorrelationContext. + + On exiting, the context manager will restore the parent + CorrelationContext. + + Args: + context: A CorrelationContext instance to make current. + """ + snapshot = current().value(self.slot_name) + set_value(self.slot_name, context) + try: + yield context + finally: + set_value(self.slot_name, snapshot) diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/distributedcontext/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/distributedcontext/__init__.py deleted file mode 100644 index a20cbf8963..0000000000 --- a/opentelemetry-sdk/src/opentelemetry/sdk/distributedcontext/__init__.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2019, OpenTelemetry Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import typing -from contextlib import contextmanager - -from opentelemetry import distributedcontext as dctx_api -from opentelemetry.context import Context - - -class DistributedContextManager(dctx_api.DistributedContextManager): - """See `opentelemetry.distributedcontext.DistributedContextManager` - - Args: - name: The name of the context manager - """ - - def __init__(self, name: str = "") -> None: - if name: - slot_name = "DistributedContext.{}".format(name) - else: - slot_name = "DistributedContext" - - self._current_context = Context.register_slot(slot_name) - - def get_current_context( - self, - ) -> typing.Optional[dctx_api.DistributedContext]: - """Gets the current DistributedContext. - - Returns: - A DistributedContext instance representing the current context. - """ - return self._current_context.get() - - @contextmanager - def use_context( - self, context: dctx_api.DistributedContext - ) -> typing.Iterator[dctx_api.DistributedContext]: - """Context manager for controlling a DistributedContext lifetime. - - Set the context as the active DistributedContext. - - On exiting, the context manager will restore the parent - DistributedContext. - - Args: - context: A DistributedContext instance to make current. - """ - snapshot = self._current_context.get() - self._current_context.set(context) - try: - yield context - finally: - self._current_context.set(snapshot) diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/distributedcontext/propagation/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/distributedcontext/propagation/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py index 3035ae7ef9..2e33f22277 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py @@ -21,11 +21,18 @@ from types import TracebackType from typing import Iterator, Optional, Sequence, Tuple, Type +from opentelemetry import context as ctx_api from opentelemetry import trace as trace_api -from opentelemetry.context import Context from opentelemetry.sdk import util from opentelemetry.sdk.util import BoundedDict, BoundedList from opentelemetry.trace import SpanContext, sampling +from opentelemetry.trace.propagation.context import ( + ContextKeys, + span_context_from_context, + span_from_context, + with_span, + with_span_context, +) from opentelemetry.trace.status import Status, StatusCanonicalCode from opentelemetry.util import time_ns, types @@ -385,9 +392,9 @@ def __init__( self.source = source self.instrumentation_info = instrumentation_info - def get_current_span(self): + def get_current_span(self, context: Optional[ctx_api.Context] = None): """See `opentelemetry.trace.Tracer.get_current_span`.""" - return self.source.get_current_span() + return span_from_context(context=context) def start_as_current_span( self, @@ -396,10 +403,13 @@ def start_as_current_span( kind: trace_api.SpanKind = trace_api.SpanKind.INTERNAL, attributes: Optional[types.Attributes] = None, links: Sequence[trace_api.Link] = (), + context: Optional[ctx_api.Context] = None, ) -> Iterator[trace_api.Span]: """See `opentelemetry.trace.Tracer.start_as_current_span`.""" - span = self.start_span(name, parent, kind, attributes, links) + span = self.start_span( + name, parent, kind, attributes, links, context=context + ) return self.use_span(span, end_on_exit=True) def start_span( # pylint: disable=too-many-locals @@ -411,16 +421,20 @@ def start_span( # pylint: disable=too-many-locals links: Sequence[trace_api.Link] = (), start_time: Optional[int] = None, set_status_on_exception: bool = True, + context: Optional[ctx_api.Context] = None, ) -> trace_api.Span: """See `opentelemetry.trace.Tracer.start_span`.""" if parent is Tracer.CURRENT_SPAN: - parent = self.get_current_span() + parent = self.get_current_span(context=context) parent_context = parent if isinstance(parent_context, trace_api.Span): parent_context = parent.get_context() + if parent_context is None: + parent_context = span_context_from_context(context) + if parent_context is not None and not isinstance( parent_context, trace_api.SpanContext ): @@ -484,16 +498,13 @@ def use_span( ) -> Iterator[trace_api.Span]: """See `opentelemetry.trace.Tracer.use_span`.""" try: - span_snapshot = self.source.get_current_span() - self.source._current_span_slot.set( # pylint:disable=protected-access - span - ) + span_snapshot = span_from_context() + with_span(span) try: yield span finally: - self.source._current_span_slot.set( # pylint:disable=protected-access - span_snapshot - ) + with_span(span_snapshot) + finally: if end_on_exit: span.end() @@ -507,7 +518,7 @@ def __init__( ): # TODO: How should multiple TracerSources behave? Should they get their own contexts? # This could be done by adding `str(id(self))` to the slot name. - self._current_span_slot = Context.register_slot("current_span") + self._current_span_name = "current_span" self._active_span_processor = MultiSpanProcessor() self.sampler = sampler self._atexit_handler = None @@ -529,8 +540,11 @@ def get_tracer( ), ) - def get_current_span(self) -> Span: - return self._current_span_slot.get() + def get_current_span( + self, context: Optional[ctx_api.Context] = None + ) -> Span: + """See `opentelemetry.trace.Tracer.get_current_span`.""" + return ctx_api.value(self._current_span_name, context=context) def add_span_processor(self, span_processor: SpanProcessor) -> None: """Registers a new :class:`SpanProcessor` for this `TracerSource`. diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py index 36459c5b73..e68807219f 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py @@ -18,7 +18,7 @@ import typing from enum import Enum -from opentelemetry.context import Context +from opentelemetry import context from opentelemetry.util import time_ns from .. import Span, SpanProcessor @@ -73,7 +73,7 @@ def on_start(self, span: Span) -> None: pass def on_end(self, span: Span) -> None: - with Context.use(suppress_instrumentation=True): + with context.use(suppress_instrumentation=True): try: self.span_exporter.export((span,)) # pylint: disable=broad-except @@ -182,7 +182,7 @@ def export(self) -> None: while idx < self.max_export_batch_size and self.queue: self.spans_list[idx] = self.queue.pop() idx += 1 - with Context.use(suppress_instrumentation=True): + with context.use(suppress_instrumentation=True): try: # Ignore type b/c the Optional[None]+slicing is too "clever" # for mypy diff --git a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py index 1215508269..646909df1a 100644 --- a/opentelemetry-sdk/tests/context/propagation/test_b3_format.py +++ b/opentelemetry-sdk/tests/context/propagation/test_b3_format.py @@ -17,13 +17,10 @@ import opentelemetry.sdk.context.propagation.b3_format as b3_format import opentelemetry.sdk.trace as trace import opentelemetry.trace as trace_api +from opentelemetry.trace.propagation.context import span_context_from_context -FORMAT = b3_format.B3Format() - - -def get_as_list(dict_object, key): - value = dict_object.get(key) - return [value] if value is not None else [] +INJECTOR = b3_format.B3Injector +EXTRACTOR = b3_format.B3Extractor class TestB3Format(unittest.TestCase): @@ -39,38 +36,42 @@ def setUpClass(cls): def test_extract_multi_header(self): """Test the extraction of B3 headers.""" carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: "1", + b3_format.TRACE_ID_KEY: self.serialized_trace_id, + b3_format.SPAN_ID_KEY: self.serialized_span_id, + b3_format.SAMPLED_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) + EXTRACTOR.extract(carrier) new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) + INJECTOR.inject( + new_carrier, set_in_carrier=dict.__setitem__, + ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id + new_carrier[b3_format.TRACE_ID_KEY], self.serialized_trace_id ) self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], self.serialized_span_id + new_carrier[b3_format.SPAN_ID_KEY], self.serialized_span_id ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[b3_format.SAMPLED_KEY], "1") def test_extract_single_header(self): """Test the extraction from a single b3 header.""" carrier = { - FORMAT.SINGLE_HEADER_KEY: "{}-{}".format( + EXTRACTOR.SINGLE_HEADER_KEY: "{}-{}".format( self.serialized_trace_id, self.serialized_span_id ) } - span_context = FORMAT.extract(get_as_list, carrier) + EXTRACTOR.extract(carrier) new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) + INJECTOR.inject( + new_carrier, set_in_carrier=dict.__setitem__, + ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id + new_carrier[b3_format.TRACE_ID_KEY], self.serialized_trace_id ) self.assertEqual( - new_carrier[FORMAT.SPAN_ID_KEY], self.serialized_span_id + new_carrier[b3_format.SPAN_ID_KEY], self.serialized_span_id ) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + self.assertEqual(new_carrier[b3_format.SAMPLED_KEY], "1") def test_extract_header_precedence(self): """A single b3 header should take precedence over multiple @@ -78,108 +79,121 @@ def test_extract_header_precedence(self): """ single_header_trace_id = self.serialized_trace_id[:-3] + "123" carrier = { - FORMAT.SINGLE_HEADER_KEY: "{}-{}".format( + EXTRACTOR.SINGLE_HEADER_KEY: "{}-{}".format( single_header_trace_id, self.serialized_span_id ), - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: "1", + b3_format.TRACE_ID_KEY: self.serialized_trace_id, + b3_format.SPAN_ID_KEY: self.serialized_span_id, + b3_format.SAMPLED_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) + EXTRACTOR.extract(carrier) new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) + INJECTOR.inject( + new_carrier, set_in_carrier=dict.__setitem__, + ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], single_header_trace_id + new_carrier[b3_format.TRACE_ID_KEY], single_header_trace_id ) def test_enabled_sampling(self): """Test b3 sample key variants that turn on sampling.""" for variant in ["1", "True", "true", "d"]: carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: variant, + b3_format.TRACE_ID_KEY: self.serialized_trace_id, + b3_format.SPAN_ID_KEY: self.serialized_span_id, + b3_format.SAMPLED_KEY: variant, } - span_context = FORMAT.extract(get_as_list, carrier) + EXTRACTOR.extract(carrier) new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + INJECTOR.inject( + new_carrier, set_in_carrier=dict.__setitem__, + ) + self.assertEqual(new_carrier[b3_format.SAMPLED_KEY], "1") def test_disabled_sampling(self): """Test b3 sample key variants that turn off sampling.""" for variant in ["0", "False", "false", None]: carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.SAMPLED_KEY: variant, + b3_format.TRACE_ID_KEY: self.serialized_trace_id, + b3_format.SPAN_ID_KEY: self.serialized_span_id, + b3_format.SAMPLED_KEY: variant, } - span_context = FORMAT.extract(get_as_list, carrier) + EXTRACTOR.extract(carrier) new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "0") + INJECTOR.inject( + new_carrier, set_in_carrier=dict.__setitem__, + ) + self.assertEqual(new_carrier[b3_format.SAMPLED_KEY], "0") def test_flags(self): """x-b3-flags set to "1" should result in propagation.""" carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + b3_format.TRACE_ID_KEY: self.serialized_trace_id, + b3_format.SPAN_ID_KEY: self.serialized_span_id, + EXTRACTOR.FLAGS_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) + + EXTRACTOR.extract(carrier) new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + INJECTOR.inject( + new_carrier, set_in_carrier=dict.__setitem__, + ) + self.assertEqual(new_carrier[b3_format.SAMPLED_KEY], "1") def test_flags_and_sampling(self): """Propagate if b3 flags and sampling are set.""" carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + b3_format.TRACE_ID_KEY: self.serialized_trace_id, + b3_format.SPAN_ID_KEY: self.serialized_span_id, + EXTRACTOR.FLAGS_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) + EXTRACTOR.extract(carrier) new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) - self.assertEqual(new_carrier[FORMAT.SAMPLED_KEY], "1") + INJECTOR.inject( + new_carrier, set_in_carrier=dict.__setitem__, + ) + self.assertEqual(new_carrier[b3_format.SAMPLED_KEY], "1") def test_64bit_trace_id(self): """64 bit trace ids should be padded to 128 bit trace ids.""" trace_id_64_bit = self.serialized_trace_id[:16] carrier = { - FORMAT.TRACE_ID_KEY: trace_id_64_bit, - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + b3_format.TRACE_ID_KEY: trace_id_64_bit, + b3_format.SPAN_ID_KEY: self.serialized_span_id, + EXTRACTOR.FLAGS_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) + EXTRACTOR.extract(carrier) new_carrier = {} - FORMAT.inject(span_context, dict.__setitem__, new_carrier) + INJECTOR.inject( + new_carrier, set_in_carrier=dict.__setitem__, + ) self.assertEqual( - new_carrier[FORMAT.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit + new_carrier[b3_format.TRACE_ID_KEY], "0" * 16 + trace_id_64_bit ) def test_invalid_single_header(self): """If an invalid single header is passed, return an invalid SpanContext. """ - carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} - span_context = FORMAT.extract(get_as_list, carrier) + carrier = {EXTRACTOR.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} + span_context = span_context_from_context(EXTRACTOR.extract(carrier)) self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) def test_missing_trace_id(self): """If a trace id is missing, populate an invalid trace id.""" carrier = { - FORMAT.SPAN_ID_KEY: self.serialized_span_id, - FORMAT.FLAGS_KEY: "1", + b3_format.SPAN_ID_KEY: self.serialized_span_id, + EXTRACTOR.FLAGS_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) + span_context = span_context_from_context(EXTRACTOR.extract(carrier)) self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) def test_missing_span_id(self): """If a trace id is missing, populate an invalid trace id.""" carrier = { - FORMAT.TRACE_ID_KEY: self.serialized_trace_id, - FORMAT.FLAGS_KEY: "1", + b3_format.TRACE_ID_KEY: self.serialized_trace_id, + EXTRACTOR.FLAGS_KEY: "1", } - span_context = FORMAT.extract(get_as_list, carrier) + span_context = span_context_from_context(EXTRACTOR.extract(carrier)) self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) diff --git a/opentelemetry-api/tests/context/propagation/test_tracecontexthttptextformat.py b/opentelemetry-sdk/tests/context/propagation/test_tracecontexthttptextformat.py similarity index 63% rename from opentelemetry-api/tests/context/propagation/test_tracecontexthttptextformat.py rename to opentelemetry-sdk/tests/context/propagation/test_tracecontexthttptextformat.py index ed952e0dba..47853563fc 100644 --- a/opentelemetry-api/tests/context/propagation/test_tracecontexthttptextformat.py +++ b/opentelemetry-sdk/tests/context/propagation/test_tracecontexthttptextformat.py @@ -16,22 +16,27 @@ import unittest from opentelemetry import trace -from opentelemetry.context.propagation import tracecontexthttptextformat +from opentelemetry.context import current +from opentelemetry.sdk.context.propagation.tracecontexthttptextformat import ( + TraceContextHTTPExtractor, + TraceContextHTTPInjector, +) +from opentelemetry.trace.propagation.context import ( + span_context_from_context, + with_span_context, +) -FORMAT = tracecontexthttptextformat.TraceContextHTTPTextFormat() - - -def get_as_list( - dict_object: typing.Dict[str, typing.List[str]], key: str -) -> typing.List[str]: - value = dict_object.get(key) - return value if value is not None else [] +INJECT = TraceContextHTTPInjector +EXTRACT = TraceContextHTTPExtractor class TestTraceContextFormat(unittest.TestCase): TRACE_ID = int("12345678901234567890123456789012", 16) # type:int SPAN_ID = int("1234567890123456", 16) # type:int + def setUp(self): + self.ctx = current() + def test_no_traceparent_header(self): """When tracecontext headers are not present, a new SpanContext should be created. @@ -41,7 +46,9 @@ def test_no_traceparent_header(self): If no traceparent header is received, the vendor creates a new trace-id and parent-id that represents the current request. """ output = {} # type:typing.Dict[str, typing.List[str]] - span_context = FORMAT.extract(get_as_list, output) + span_context = span_context_from_context( + EXTRACT.extract(output, self.ctx) + ) self.assertTrue(isinstance(span_context, trace.SpanContext)) def test_headers_with_tracestate(self): @@ -53,12 +60,14 @@ def test_headers_with_tracestate(self): span_id=format(self.SPAN_ID, "016x"), ) tracestate_value = "foo=1,bar=2,baz=3" - span_context = FORMAT.extract( - get_as_list, - { - "traceparent": [traceparent_value], - "tracestate": [tracestate_value], - }, + span_context = span_context_from_context( + EXTRACT.extract( + { + "traceparent": [traceparent_value], + "tracestate": [tracestate_value], + }, + self.ctx, + ) ) self.assertEqual(span_context.trace_id, self.TRACE_ID) self.assertEqual(span_context.span_id, self.SPAN_ID) @@ -67,7 +76,7 @@ def test_headers_with_tracestate(self): ) output = {} # type:typing.Dict[str, str] - FORMAT.inject(span_context, dict.__setitem__, output) + INJECT.inject(output, set_in_carrier=dict.__setitem__) self.assertEqual(output["traceparent"], traceparent_value) for pair in ["foo=1", "bar=2", "baz=3"]: self.assertIn(pair, output["tracestate"]) @@ -89,14 +98,16 @@ def test_invalid_trace_id(self): If the vendor failed to parse traceparent, it MUST NOT attempt to parse tracestate. Note that the opposite is not true: failure to parse tracestate MUST NOT affect the parsing of traceparent. """ - span_context = FORMAT.extract( - get_as_list, - { - "traceparent": [ - "00-00000000000000000000000000000000-1234567890123456-00" - ], - "tracestate": ["foo=1,bar=2,foo=3"], - }, + span_context = span_context_from_context( + EXTRACT.extract( + { + "traceparent": [ + "00-00000000000000000000000000000000-1234567890123456-00" + ], + "tracestate": ["foo=1,bar=2,foo=3"], + }, + self.ctx, + ) ) self.assertEqual(span_context, trace.INVALID_SPAN_CONTEXT) @@ -115,14 +126,16 @@ def test_invalid_parent_id(self): If the vendor failed to parse traceparent, it MUST NOT attempt to parse tracestate. Note that the opposite is not true: failure to parse tracestate MUST NOT affect the parsing of traceparent. """ - span_context = FORMAT.extract( - get_as_list, - { - "traceparent": [ - "00-00000000000000000000000000000000-0000000000000000-00" - ], - "tracestate": ["foo=1,bar=2,foo=3"], - }, + span_context = span_context_from_context( + EXTRACT.extract( + { + "traceparent": [ + "00-00000000000000000000000000000000-0000000000000000-00" + ], + "tracestate": ["foo=1,bar=2,foo=3"], + }, + self.ctx, + ) ) self.assertEqual(span_context, trace.INVALID_SPAN_CONTEXT) @@ -134,11 +147,10 @@ def test_no_send_empty_tracestate(self): Empty and whitespace-only list members are allowed. Vendors MUST accept empty tracestate headers but SHOULD avoid sending them. """ + ctx = with_span_context(trace.SpanContext(self.TRACE_ID, self.SPAN_ID)) output = {} # type:typing.Dict[str, str] - FORMAT.inject( - trace.SpanContext(self.TRACE_ID, self.SPAN_ID), - dict.__setitem__, - output, + INJECT.inject( + output, ctx, dict.__setitem__, ) self.assertTrue("traceparent" in output) self.assertFalse("tracestate" in output) @@ -151,14 +163,16 @@ def test_format_not_supported(self): If the version cannot be parsed, return an invalid trace header. """ - span_context = FORMAT.extract( - get_as_list, - { - "traceparent": [ - "00-12345678901234567890123456789012-1234567890123456-00-residue" - ], - "tracestate": ["foo=1,bar=2,foo=3"], - }, + span_context = span_context_from_context( + EXTRACT.extract( + { + "traceparent": [ + "00-12345678901234567890123456789012-1234567890123456-00-residue" + ], + "tracestate": ["foo=1,bar=2,foo=3"], + }, + self.ctx, + ) ) self.assertEqual(span_context, trace.INVALID_SPAN_CONTEXT) @@ -166,32 +180,37 @@ def test_propagate_invalid_context(self): """Do not propagate invalid trace context. """ output = {} # type:typing.Dict[str, str] - FORMAT.inject(trace.INVALID_SPAN_CONTEXT, dict.__setitem__, output) + ctx = with_span_context(trace.INVALID_SPAN_CONTEXT) + INJECT.inject(output, ctx, dict.__setitem__) self.assertFalse("traceparent" in output) def test_tracestate_empty_header(self): """Test tracestate with an additional empty header (should be ignored)""" - span_context = FORMAT.extract( - get_as_list, - { - "traceparent": [ - "00-12345678901234567890123456789012-1234567890123456-00" - ], - "tracestate": ["foo=1", ""], - }, + span_context = span_context_from_context( + EXTRACT.extract( + { + "traceparent": [ + "00-12345678901234567890123456789012-1234567890123456-00" + ], + "tracestate": ["foo=1", ""], + }, + self.ctx, + ) ) self.assertEqual(span_context.trace_state["foo"], "1") def test_tracestate_header_with_trailing_comma(self): """Do not propagate invalid trace context. """ - span_context = FORMAT.extract( - get_as_list, - { - "traceparent": [ - "00-12345678901234567890123456789012-1234567890123456-00" - ], - "tracestate": ["foo=1,"], - }, + span_context = span_context_from_context( + EXTRACT.extract( + { + "traceparent": [ + "00-12345678901234567890123456789012-1234567890123456-00" + ], + "tracestate": ["foo=1,"], + }, + self.ctx, + ) ) self.assertEqual(span_context.trace_state["foo"], "1") diff --git a/opentelemetry-sdk/tests/context/test_futures.py b/opentelemetry-sdk/tests/context/test_futures.py new file mode 100644 index 0000000000..5761bca3ed --- /dev/null +++ b/opentelemetry-sdk/tests/context/test_futures.py @@ -0,0 +1,84 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import concurrent.futures +import unittest + +from opentelemetry import context +from opentelemetry.sdk import trace +from opentelemetry.sdk.trace import export +from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, +) + + +class TestContextWithFutures(unittest.TestCase): + span_names = [ + "test_span1", + "test_span2", + "test_span3", + "test_span4", + "test_span5", + ] + + def do_work(self, name="default"): + with self.tracer.start_as_current_span(name): + context.set_value("say-something", "bar") + + def setUp(self): + self.tracer_source = trace.TracerSource() + self.tracer = self.tracer_source.get_tracer(__name__) + self.memory_exporter = InMemorySpanExporter() + span_processor = export.SimpleExportSpanProcessor(self.memory_exporter) + self.tracer_source.add_span_processor(span_processor) + + def test_with_futures(self): + try: + import contextvars # pylint: disable=import-outside-toplevel + except ImportError: + self.skipTest("contextvars not available") + + with self.tracer.start_as_current_span("futures_test"): + with concurrent.futures.ThreadPoolExecutor( + max_workers=5 + ) as executor: + # Start the load operations + for span in self.span_names: + executor.submit( + contextvars.copy_context().run, self.do_work, span, + ) + span_list = self.memory_exporter.get_finished_spans() + span_names_list = [span.name for span in span_list] + + expected = [ + "test_span1", + "test_span2", + "test_span3", + "test_span4", + "test_span5", + "futures_test", + ] + + self.assertCountEqual(span_names_list, expected) + span_names_list.sort() + expected.sort() + self.assertListEqual(span_names_list, expected) + # expected_parent = next( + # span for span in span_list if span.name == "futures_test" + # ) + # TODO: ensure the following passes + # for span in span_list: + # if span is expected_parent: + # continue + # self.assertEqual(span.parent, expected_parent) + # diff --git a/opentelemetry-sdk/tests/context/test_threads.py b/opentelemetry-sdk/tests/context/test_threads.py new file mode 100644 index 0000000000..39c6428f76 --- /dev/null +++ b/opentelemetry-sdk/tests/context/test_threads.py @@ -0,0 +1,75 @@ +# Copyright 2019, OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from multiprocessing.dummy import Pool as ThreadPool + +from opentelemetry import context +from opentelemetry.sdk import trace +from opentelemetry.sdk.trace import export +from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, +) +from opentelemetry.trace.propagation.context import span_from_context + + +class TestContext(unittest.TestCase): + span_names = [ + "test_span1", + "test_span2", + "test_span3", + "test_span4", + "test_span5", + ] + + def do_work(self, name="default"): + with self.tracer.start_as_current_span(name): + context.set_value("say-something", "bar") + + def setUp(self): + self.tracer_source = trace.TracerSource() + self.tracer = self.tracer_source.get_tracer(__name__) + self.memory_exporter = InMemorySpanExporter() + span_processor = export.SimpleExportSpanProcessor(self.memory_exporter) + self.tracer_source.add_span_processor(span_processor) + + def test_with_threads(self): + with self.tracer.start_as_current_span("threads_test"): + print(span_from_context()) + pool = ThreadPool(5) # create a thread pool + pool.map(self.do_work, self.span_names) + pool.close() + pool.join() + span_list = self.memory_exporter.get_finished_spans() + span_names_list = [span.name for span in span_list] + expected = [ + "test_span1", + "test_span2", + "test_span3", + "test_span4", + "test_span5", + "threads_test", + ] + self.assertCountEqual(span_names_list, expected) + span_names_list.sort() + expected.sort() + self.assertListEqual(span_names_list, expected) + # expected_parent = next( + # span for span in span_list if span.name == "threads_test" + # ) + # TODO: ensure the following passes + # for span in span_list: + # if span is expected_parent: + # continue + # self.assertEqual(span.parent, expected_parent) diff --git a/opentelemetry-sdk/tests/distributedcontext/__init__.py b/opentelemetry-sdk/tests/correlationcontext/__init__.py similarity index 100% rename from opentelemetry-sdk/tests/distributedcontext/__init__.py rename to opentelemetry-sdk/tests/correlationcontext/__init__.py diff --git a/opentelemetry-sdk/tests/distributedcontext/test_distributed_context.py b/opentelemetry-sdk/tests/correlationcontext/test_distributed_context.py similarity index 62% rename from opentelemetry-sdk/tests/distributedcontext/test_distributed_context.py rename to opentelemetry-sdk/tests/correlationcontext/test_distributed_context.py index eddb61330d..2f151fea8e 100644 --- a/opentelemetry-sdk/tests/distributedcontext/test_distributed_context.py +++ b/opentelemetry-sdk/tests/correlationcontext/test_distributed_context.py @@ -14,29 +14,29 @@ import unittest -from opentelemetry import distributedcontext as dctx_api -from opentelemetry.sdk import distributedcontext +from opentelemetry import correlationcontext as cctx_api +from opentelemetry.sdk import correlationcontext -class TestDistributedContextManager(unittest.TestCase): +class TestCorrelationContextManager(unittest.TestCase): def setUp(self): - self.manager = distributedcontext.DistributedContextManager() + self.manager = correlationcontext.CorrelationContextManager() def test_use_context(self): # Context is None initially - self.assertIsNone(self.manager.get_current_context()) + self.assertIsNone(self.manager.current_context()) # Start initial context - dctx = dctx_api.DistributedContext(()) + dctx = cctx_api.CorrelationContext() with self.manager.use_context(dctx) as current: self.assertIs(current, dctx) - self.assertIs(self.manager.get_current_context(), dctx) + self.assertIs(self.manager.current_context(), dctx) # Context is overridden - nested_dctx = dctx_api.DistributedContext(()) + nested_dctx = cctx_api.CorrelationContext() with self.manager.use_context(nested_dctx) as current: self.assertIs(current, nested_dctx) - self.assertIs(self.manager.get_current_context(), nested_dctx) + self.assertIs(self.manager.current_context(), nested_dctx) # Context is restored - self.assertIs(self.manager.get_current_context(), dctx) + self.assertIs(self.manager.current_context(), dctx) diff --git a/opentelemetry-sdk/tests/trace/test_trace.py b/opentelemetry-sdk/tests/trace/test_trace.py index 98a7bb100e..4486004845 100644 --- a/opentelemetry-sdk/tests/trace/test_trace.py +++ b/opentelemetry-sdk/tests/trace/test_trace.py @@ -18,6 +18,7 @@ from unittest import mock from opentelemetry import trace as trace_api +from opentelemetry.context import new_context, set_current from opentelemetry.sdk import trace from opentelemetry.trace import sampling from opentelemetry.trace.status import StatusCanonicalCode @@ -130,6 +131,9 @@ def test_sampler_no_sampling(self): class TestSpanCreation(unittest.TestCase): + def setUp(self): + set_current(new_context()) + def test_start_span_invalid_spancontext(self): """If an invalid span context is passed as the parent, the created span should use a new span id. diff --git a/tests/w3c_tracecontext_validation_server.py b/tests/w3c_tracecontext_validation_server.py index bea4d4fde5..eeed85f989 100644 --- a/tests/w3c_tracecontext_validation_server.py +++ b/tests/w3c_tracecontext_validation_server.py @@ -23,15 +23,20 @@ import flask import requests -from opentelemetry import trace +from opentelemetry import propagation, trace from opentelemetry.ext import http_requests from opentelemetry.ext.wsgi import OpenTelemetryMiddleware +from opentelemetry.sdk.context.propagation import tracecontexthttptextformat from opentelemetry.sdk.trace import TracerSource from opentelemetry.sdk.trace.export import ( ConsoleSpanExporter, SimpleExportSpanProcessor, ) +(w3c_extractor, w3c_injector) = tracecontexthttptextformat.http_propagator() +propagation.set_http_extractors([w3c_extractor]) +propagation.set_http_injectors([w3c_injector]) + # The preferred tracer implementation must be set, as the opentelemetry-api # defines the interface with a no-op implementation. trace.set_preferred_tracer_source_implementation(lambda T: TracerSource())