Source code for mlflow.tracing.distributed

import logging
from contextlib import contextmanager

from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

import mlflow
from mlflow.telemetry.events import TracingContextPropagation
from mlflow.telemetry.track import record_usage_event
from mlflow.tracing.provider import get_context_api, get_current_context, get_current_otel_span

_logger = logging.getLogger(__name__)


@record_usage_event(TracingContextPropagation)
def get_tracing_context_headers_for_http_request() -> dict[str, str]:
    """
    Get the http request headers that hold information of the tracing context.

    The trace context is serialized as the traceparent header which is defined in the
    W3C TraceContext specification.
    For details, you can refer to
    https://opentelemetry.io/docs/concepts/context-propagation/ and
    https://www.w3.org/TR/trace-context/#traceparent-header

    Returns:
        The http request headers that hold information of the tracing context.
        When there is no active span, a warning is logged and an empty dict is returned.

    Example (client code):

    .. code-block:: python

        import mlflow
        import requests
        from mlflow.tracing import get_tracing_context_headers_for_http_request

        with mlflow.start_span("client-root") as client_span:
            # Get the headers that hold information of the tracing context,
            # and send request to remote agent with the headers
            headers = get_tracing_context_headers_for_http_request()
            resp = requests.post(f"{base_url}/remote_agent_handler", headers=headers)

    Example (server handler code):

    .. code-block:: python

        import mlflow
        from flask import Flask, request
        from mlflow.tracing import set_tracing_context_from_http_request_headers

        app = Flask(__name__)


        @app.post("/agent-handler")
        def handle():
            headers = dict(request.headers)
            with set_tracing_context_from_http_request_headers(headers):
                with mlflow.start_span("server-handler") as span:
                    # call agent ...
                    span.set_attribute("key", "value")
    """
    if mlflow.get_current_active_span() is None:
        _logger.warning(
            "No active span found for fetching the trace context from. Returning an empty header."
        )
        # Return here so the result always matches the warning above. Falling through
        # to `inject()` could still emit a header if the current context carries a
        # span context without an active MLflow span.
        return {}

    headers: dict[str, str] = {}
    # `inject` fills the carrier dict in place from the current tracing context.
    TraceContextTextMapPropagator().inject(carrier=headers, context=get_current_context())
    return headers
def _lowercase_trace_context_header_names(headers: dict[str, str]) -> dict[str, str]:
    """Return a copy of ``headers`` with W3C trace context header names lower-cased."""
    normalized = dict(headers)
    # Iterate over a snapshot of the keys because entries are re-keyed in the loop.
    for name in list(normalized):
        if not isinstance(name, str):
            continue
        lowered = name.lower()
        # A differently-cased variant replaces an existing lower-case entry, which is
        # what the previous 'Traceparent' -> 'traceparent' special case did.
        if lowered != name and lowered in ("traceparent", "tracestate"):
            normalized[lowered] = normalized.pop(name)
    return normalized


@record_usage_event(TracingContextPropagation)
@contextmanager
def set_tracing_context_from_http_request_headers(headers: dict[str, str]):
    """
    Context manager to extract the trace context from the http request headers
    and set the extracted trace context as the current trace context within the
    scope of this context manager.

    The trace context must be serialized as the 'traceparent' header which is defined
    in the W3C TraceContext specification, please see
    :py:func:`mlflow.tracing.get_tracing_context_headers_for_http_request`
    for how to get the http request headers.

    Args:
        headers: The http request headers to extract the trace context from.
            The 'traceparent' and 'tracestate' header names are matched
            case-insensitively. The passed mapping is not modified.

    Raises:
        MlflowException: If the headers do not contain a 'traceparent' entry.

    Example (client code):

    .. code-block:: python

        import mlflow
        import requests
        from mlflow.tracing import get_tracing_context_headers_for_http_request

        with mlflow.start_span("client-root") as client_span:
            # Get the headers that hold information of the tracing context,
            # and send request to remote agent with the headers
            headers = get_tracing_context_headers_for_http_request()
            resp = requests.post(f"{base_url}/remote_agent_handler", headers=headers)

    Example (server handler code):

    .. code-block:: python

        import mlflow
        from flask import Flask, request
        from mlflow.tracing import set_tracing_context_from_http_request_headers

        app = Flask(__name__)


        @app.post("/agent-handler")
        def handle():
            headers = dict(request.headers)
            with set_tracing_context_from_http_request_headers(headers):
                with mlflow.start_span("server-handler") as span:
                    # call agent ...
                    span.set_attribute("key", "value")
    """
    from mlflow import MlflowException
    from mlflow.entities.trace_info import TraceInfo, TraceState
    from mlflow.tracing.trace_manager import InMemoryTraceManager
    from mlflow.tracing.utils import generate_mlflow_trace_id_from_otel_trace_id

    # Both stay None until the matching setup step succeeds, so the `finally`
    # clause only undoes what was actually done.
    token = None
    otel_trace_id = None
    trace_manager = InMemoryTraceManager.get_instance()
    try:
        # Note: Some http server frameworks (e.g. flask) change the case of http header
        # keys (e.g. 'Traceparent'), but `TraceContextTextMapPropagator` can't recognize
        # such keys, so the trace context header names are converted to lower case.
        # Header names are case-insensitive, so every casing is accepted.
        headers = _lowercase_trace_context_header_names(headers)

        if "traceparent" not in headers:
            raise MlflowException.invalid_parameter_value(
                "The http request headers do not contain the required key 'traceparent', "
                "please generate the request headers "
                "by 'mlflow.tracing.distributed.get_tracing_context_headers_for_http_request' "
                "API."
            )

        ctx = TraceContextTextMapPropagator().extract(headers)
        token = get_context_api().attach(ctx)

        # Read the trace id back from the span that is current after attaching the
        # extracted context.
        extracted_span = get_current_otel_span()
        span_context = extracted_span.get_span_context()
        otel_trace_id = span_context.trace_id
        trace_id = generate_mlflow_trace_id_from_otel_trace_id(otel_trace_id)

        # Register a placeholder trace info for the extracted trace id, flagged as a
        # remote trace, for the duration of the context manager.
        dummy_trace_info = TraceInfo(
            trace_id=trace_id,
            trace_location=None,
            request_time=None,
            state=TraceState.IN_PROGRESS,
        )
        trace_manager.register_trace(otel_trace_id, dummy_trace_info, is_remote_trace=True)
        yield
    finally:
        if token is not None:
            get_context_api().detach(token)
        if otel_trace_id is not None:
            trace_manager.pop_trace(otel_trace_id)