Source code for mlflow.tracing.utils

# TODO: Split this file into multiple files and move under utils directory.
from __future__ import annotations

import inspect
import json
import logging
import uuid
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import asdict, is_dataclass
from typing import TYPE_CHECKING, Any, Generator

from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan
from opentelemetry.sdk.trace import Span as OTelSpan

from mlflow.exceptions import BAD_REQUEST, MlflowException, MlflowTracingException
from mlflow.tracing.constant import (
    ASSESSMENT_ID_PREFIX,
    TRACE_ID_V4_PREFIX,
    TRACE_REQUEST_ID_PREFIX,
    SpanAttributeKey,
    TokenUsageKey,
    TraceMetadataKey,
    TraceSizeStatsKey,
)
from mlflow.tracing.constant import (
    CostKey as CostKey,
)
from mlflow.utils.mlflow_tags import IMMUTABLE_TAGS
from mlflow.version import IS_TRACING_SDK_ONLY

_logger = logging.getLogger(__name__)

SPANS_COLUMN_NAME = "spans"

if TYPE_CHECKING:
    from mlflow.entities import LiveSpan, Trace
    from mlflow.pyfunc.context import Context
    from mlflow.types.chat import ChatTool


def capture_function_input_args(func, args, kwargs) -> dict[str, Any] | None:
    try:
        func_signature = inspect.signature(func)
        bound_arguments = func_signature.bind(*args, **kwargs)
        bound_arguments.apply_defaults()

        # Remove `self` from bound arguments if it exists
        if "self" in bound_arguments.arguments:
            del bound_arguments.arguments["self"]

        # Remove `cls` from bound arguments if it's the first parameter and it's a type
        # This detects classmethods more reliably
        params = list(bound_arguments.arguments.keys())
        if params and params[0] == "cls" and isinstance(bound_arguments.arguments["cls"], type):
            del bound_arguments.arguments["cls"]

        return bound_arguments.arguments
    except Exception:
        _logger.warning(f"Failed to capture inputs for function {func.__name__}.")
        return None

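
# Illustrative sketch (not part of the module API): how inputs are captured for a
# bound method. `_example_capture_function_input_args` and `_Greeter` are
# hypothetical names used only for this example.
def _example_capture_function_input_args() -> None:
    class _Greeter:
        def greet(self, name, punctuation="!"):
            return f"Hello, {name}{punctuation}"

    inputs = capture_function_input_args(_Greeter.greet, (_Greeter(), "Ada"), {})
    # `self` is dropped, and the default for `punctuation` is filled in by
    # `apply_defaults()`.
    assert inputs == {"name": "Ada", "punctuation": "!"}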

class TraceJSONEncoder(json.JSONEncoder):
    """
    Custom JSON encoder for serializing non-OpenTelemetry compatible objects in a trace or span.

    Trace may contain types that require custom serialization logic, such as Pydantic models,
    non-JSON-serializable types, etc.
    """

    def default(self, obj):
        try:
            import pydantic

            if isinstance(obj, pydantic.BaseModel):
                return obj.model_dump()
        except ImportError:
            pass

        # Some dataclass objects define a __str__ method that doesn't return the full
        # object representation, so we use the dict representation instead.
        # E.g. https://github.com/run-llama/llama_index/blob/29ece9b058f6b9a1cf29bc723ed4aa3a39879ad5/llama-index-core/llama_index/core/chat_engine/types.py#L63-L64
        if is_dataclass(obj):
            try:
                return asdict(obj)
            except TypeError:
                pass

        # Some objects have dangerous side effects in their __str__ method, so we use
        # the class name instead.
        if not self._is_safe_to_encode_str(obj):
            return str(type(obj))

        try:
            return super().default(obj)
        except TypeError:
            return str(obj)

    def _is_safe_to_encode_str(self, obj) -> bool:
        """Check if it's safe to encode the object as a string."""
        try:
            # These Llama Index objects are not safe to encode as a string, because
            # their __str__ method consumes the stream and makes it unusable.
            # E.g. https://github.com/run-llama/llama_index/blob/54f2da61ba8a573284ab8336f2b2810d948c3877/llama-index-core/llama_index/core/base/response/schema.py#L120-L127
            from llama_index.core.base.response.schema import (
                AsyncStreamingResponse,
                StreamingResponse,
            )
            from llama_index.core.chat_engine.types import StreamingAgentChatResponse

            if isinstance(
                obj,
                (AsyncStreamingResponse, StreamingResponse, StreamingAgentChatResponse),
            ):
                return False
        except ImportError:
            pass

        return True


def dump_span_attribute_value(value: Any) -> str:
    # NB: OpenTelemetry attributes can store not only strings but also a few primitives
    #   like int, float, bool, and lists of them. However, we serialize everything into
    #   a JSON string here to simplify deserialization.
    return json.dumps(value, cls=TraceJSONEncoder, ensure_ascii=False)

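
# Illustrative sketch of the encoder's dataclass fallback; `_Message` and the
# example function are hypothetical names, not part of the module API.
def _example_dump_span_attribute_value() -> None:
    from dataclasses import dataclass

    @dataclass
    class _Message:
        role: str
        content: str

    # Dataclasses are not natively JSON-serializable; TraceJSONEncoder falls back
    # to asdict(), producing a plain JSON object.
    assert (
        dump_span_attribute_value(_Message("user", "hi"))
        == '{"role": "user", "content": "hi"}'
    )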

def encode_span_id(span_id: int) -> str:
    """
    Encode the given integer span ID to a 16-character hex string.
    # https://github.com/open-telemetry/opentelemetry-python/blob/9398f26ecad09e02ad044859334cd4c75299c3cd/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py#L507-L508
    # NB: We don't add '0x' prefix to the hex string here for simpler parsing in backend.
    #   Some backend (e.g. Databricks) disallow this prefix.
    """
    return trace_api.format_span_id(span_id)


def encode_trace_id(trace_id: int) -> str:
    """
    Encode the given integer trace ID to a 32-character hex string.
    """
    return trace_api.format_trace_id(trace_id)


def decode_id(span_or_trace_id: str) -> int:
    """
    Decode the given hex string span or trace ID to an integer.
    """
    return int(span_or_trace_id, 16)

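
# Illustrative sketch: the encoders produce zero-padded, lowercase hex with no
# "0x" prefix, and decode_id inverts either encoding. Hypothetical example only.
def _example_id_roundtrip() -> None:
    assert encode_span_id(5) == "0000000000000005"  # 16 hex chars (8-byte ID)
    assert encode_trace_id(5) == "00000000000000000000000000000005"  # 32 hex chars
    assert decode_id(encode_trace_id(5)) == 5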

def get_mlflow_span_for_otel_span(span: OTelSpan) -> LiveSpan | None:
    """
    Get the active MLflow span for the given OpenTelemetry span.
    """
    from mlflow.tracing.trace_manager import InMemoryTraceManager

    trace_id = get_otel_attribute(span, SpanAttributeKey.REQUEST_ID)
    mlflow_span_id = encode_span_id(span.get_span_context().span_id)
    return InMemoryTraceManager.get_instance().get_span_from_id(trace_id, mlflow_span_id)


def build_otel_context(trace_id: int, span_id: int) -> trace_api.SpanContext:
    """
    Build an OpenTelemetry SpanContext object from the given trace and span IDs.
    """
    return trace_api.SpanContext(
        trace_id=trace_id,
        span_id=span_id,
        # NB: This flag is an OpenTelemetry concept indicating whether the context was
        # propagated from a remote parent. We don't support distributed tracing yet,
        # so it is always set to False.
        is_remote=False,
    )


def _aggregate_from_spans(
    spans: list[LiveSpan],
    attribute_key: str,
    input_key: str,
    output_key: str,
    total_key: str,
    default: int | float,
) -> dict[str, int | float] | None:
    """Generic aggregation of data from spans using DFS traversal.

    Avoids double-counting by skipping spans whose ancestors already have the data.

    Args:
        spans: List of spans to aggregate from.
        attribute_key: The span attribute key to look up.
        input_key: Key for extracting input value from span data.
        output_key: Key for extracting output value from span data.
        total_key: Key for extracting total value from span data.
        default: Default value (0 for int, 0.0 for float) that also determines return type.

    Returns:
        Aggregated dictionary with the keys, or None if no data found.
    """
    input_val = default
    output_val = default
    total_val = default
    has_data = False

    span_id_to_spans = {span.span_id: span for span in spans}
    children_map: defaultdict[str, list[LiveSpan]] = defaultdict(list)
    roots: list[LiveSpan] = []

    for span in spans:
        parent_id = span.parent_id
        if parent_id and parent_id in span_id_to_spans:
            children_map[parent_id].append(span)
        else:
            roots.append(span)

    def dfs(span: LiveSpan, ancestor_has_data: bool) -> None:
        nonlocal input_val, output_val, total_val, has_data

        data = span.get_attribute(attribute_key)
        span_has_data = data is not None

        if span_has_data and not ancestor_has_data:
            input_val += data.get(input_key, default)
            output_val += data.get(output_key, default)
            total_val += data.get(total_key, default)
            has_data = True

        next_ancestor_has_data = ancestor_has_data or span_has_data
        for child in children_map.get(span.span_id, []):
            dfs(child, next_ancestor_has_data)

    for root in roots:
        dfs(root, False)

    if not has_data:
        return None

    return {
        input_key: input_val,
        output_key: output_val,
        total_key: total_val,
    }


def aggregate_usage_from_spans(spans: list[LiveSpan]) -> dict[str, int] | None:
    """Aggregate token usage information from all spans in the trace."""
    return _aggregate_from_spans(
        spans,
        SpanAttributeKey.CHAT_USAGE,
        TokenUsageKey.INPUT_TOKENS,
        TokenUsageKey.OUTPUT_TOKENS,
        TokenUsageKey.TOTAL_TOKENS,
        0,
    )


def aggregate_cost_from_spans(spans: list[LiveSpan]) -> dict[str, float] | None:
    """Aggregate cost information from all spans in the trace."""
    return _aggregate_from_spans(
        spans,
        SpanAttributeKey.LLM_COST,
        CostKey.INPUT_COST,
        CostKey.OUTPUT_COST,
        CostKey.TOTAL_COST,
        0.0,
    )

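
# Illustrative sketch of the double-count rule, using a minimal stand-in that has
# only the fields _aggregate_from_spans reads (span_id, parent_id, get_attribute).
# `_StubSpan` and the example function are hypothetical, not part of the module API.
def _example_aggregate_usage() -> None:
    class _StubSpan:
        def __init__(self, span_id, parent_id, usage):
            self.span_id = span_id
            self.parent_id = parent_id
            self._usage = usage

        def get_attribute(self, key):
            return self._usage if key == SpanAttributeKey.CHAT_USAGE else None

    usage = {
        TokenUsageKey.INPUT_TOKENS: 10,
        TokenUsageKey.OUTPUT_TOKENS: 5,
        TokenUsageKey.TOTAL_TOKENS: 15,
    }
    # The child reports the same usage the root already includes, so it is skipped
    # during the DFS and its tokens are not counted twice.
    root = _StubSpan("s1", None, usage)
    child = _StubSpan("s2", "s1", usage)
    assert aggregate_usage_from_spans([root, child]) == usage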

def calculate_span_cost(span: LiveSpan) -> dict[str, float] | None:
    """Calculate cost for a single span using LiteLLM pricing data.

    Args:
        span: The span to calculate cost for.

    Returns:
        Dictionary with input_cost, output_cost, and total_cost in USD,
        or None if cost cannot be calculated.
    """
    model_name = span.get_attribute(SpanAttributeKey.MODEL)
    usage = span.get_attribute(SpanAttributeKey.CHAT_USAGE)
    model_provider = span.get_attribute(SpanAttributeKey.MODEL_PROVIDER)
    return calculate_cost_by_model_and_token_usage(model_name, usage, model_provider)


def calculate_cost_by_model_and_token_usage(
    model_name: str | None, usage: dict[str, int] | None, model_provider: str | None = None
) -> dict[str, float] | None:
    if not model_name or not usage:
        return None

    try:
        from litellm import cost_per_token
    except ImportError:
        _logger.debug("LiteLLM not available for cost calculation")
        return None

    prompt_tokens = usage.get(TokenUsageKey.INPUT_TOKENS, 0)
    completion_tokens = usage.get(TokenUsageKey.OUTPUT_TOKENS, 0)

    if prompt_tokens == 0 and completion_tokens == 0:
        return None

    try:
        input_cost_usd, output_cost_usd = cost_per_token(
            model=model_name, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
        )
    except Exception as e:
        if model_provider:
            # Pass model_provider only in the fallback case: model_name alone is often
            # enough to calculate cost, and the model_provider field can hold arbitrary
            # values that litellm may not support.
            try:
                input_cost_usd, output_cost_usd = cost_per_token(
                    model=model_name,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    custom_llm_provider=model_provider,
                )
            except Exception as e:
                _logger.debug(
                    f"Failed to calculate cost for model {model_name}: {e}", exc_info=True
                )
                return None
        else:
            _logger.debug(
                f"Failed to calculate cost for model {model_name} without provider: {e}",
                exc_info=True,
            )
            return None

    return {
        CostKey.INPUT_COST: input_cost_usd,
        CostKey.OUTPUT_COST: output_cost_usd,
        CostKey.TOTAL_COST: input_cost_usd + output_cost_usd,
    }

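
# Illustrative usage of the cost helper. It requires `litellm` to be installed,
# and "gpt-4o-mini" is just an example model name; for unknown models (or without
# litellm) the helper returns None rather than raising.
def _example_calculate_cost() -> None:
    cost = calculate_cost_by_model_and_token_usage(
        "gpt-4o-mini",
        {TokenUsageKey.INPUT_TOKENS: 1000, TokenUsageKey.OUTPUT_TOKENS: 500},
    )
    if cost is not None:
        # Costs are reported in USD, keyed by CostKey.
        assert cost[CostKey.TOTAL_COST] == (
            cost[CostKey.INPUT_COST] + cost[CostKey.OUTPUT_COST]
        )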

def get_otel_attribute(span: trace_api.Span, key: str) -> str | None:
    """
    Get the attribute value from the OpenTelemetry span in a decoded format.

    Args:
        span: The OpenTelemetry span object.
        key: The key of the attribute to retrieve.

    Returns:
        The attribute value as decoded string. If the attribute is not found or cannot
        be parsed, return None.
    """
    try:
        attribute_value = span.attributes.get(key)
        if attribute_value is None:
            return None
        return json.loads(attribute_value)
    except Exception:
        _logger.debug(f"Failed to get attribute {key} with from span {span}.", exc_info=True)


def _try_get_prediction_context():
    # NB: Tracing is enabled in mlflow-skinny, but the pyfunc module cannot be imported as it
    #     relies on numpy, which is not installed in skinny.
    try:
        from mlflow.pyfunc.context import get_prediction_context
    except (ImportError, KeyError):
        return

    return get_prediction_context()


def maybe_get_request_id(is_evaluate=False) -> str | None:
    """Get the request ID if the current prediction is as a part of MLflow model evaluation."""
    context = _try_get_prediction_context()
    if not context or (is_evaluate and not context.is_evaluate):
        return None

    if not context.request_id and is_evaluate:
        _logger.warning(
            f"Missing request_id for context {context}. request_id can't be None when "
            "is_evaluate=True. This is likely an internal error of MLflow, please file "
            "a bug report at https://github.com/mlflow/mlflow/issues."
        )
        return None

    return context.request_id


def maybe_get_dependencies_schemas() -> dict[str, Any] | None:
    if context := _try_get_prediction_context():
        return context.dependencies_schemas


def maybe_get_logged_model_id() -> str | None:
    """
    Get the logged model ID associated with the current prediction context.
    """
    if context := _try_get_prediction_context():
        return context.model_id


def exclude_immutable_tags(tags: dict[str, str]) -> dict[str, str]:
    """Exclude immutable tags e.g. "mlflow.user" from the given tags."""
    return {k: v for k, v in tags.items() if k not in IMMUTABLE_TAGS}


def generate_mlflow_trace_id_from_otel_trace_id(otel_trace_id: int) -> str:
    """
    Generate an MLflow trace ID from an OpenTelemetry trace ID.

    Args:
        otel_trace_id: The OpenTelemetry trace ID as an integer.

    Returns:
        The MLflow trace ID string in format "tr-<hex_trace_id>".
    """
    return TRACE_REQUEST_ID_PREFIX + encode_trace_id(otel_trace_id)


def generate_trace_id_v4_from_otel_trace_id(otel_trace_id: int, location: str) -> str:
    """
    Generate a trace ID in v4 format from the given OpenTelemetry trace ID.

    Args:
        otel_trace_id: The OpenTelemetry trace ID as an integer.
        location: The location of the trace.

    Returns:
        The MLflow trace ID string in format "trace:/<location>/<hex_trace_id>".
    """
    return construct_trace_id_v4(location, encode_trace_id(otel_trace_id))


def generate_trace_id_v4(span: OTelSpan, location: str) -> str:
    """
    Generate a trace ID for the given span.

    Args:
        span: The OpenTelemetry span object.
        location: The location of the trace.

    Returns:
        Trace ID with format "trace:/<location>/<hex_trace_id>".
    """
    return generate_trace_id_v4_from_otel_trace_id(span.context.trace_id, location)


def generate_trace_id_v3(span: OTelSpan) -> str:
    """
    Generate a trace ID for the given span (V3 trace schema).

    The format will be "tr-<trace_id>" where the trace_id is hex-encoded Otel trace ID.
    """
    return generate_mlflow_trace_id_from_otel_trace_id(span.context.trace_id)


def generate_request_id_v2() -> str:
    """
    Generate a random request ID.

    This should only be used for the V2 trace schema, where we use a random UUID as
    the request ID. In the V3 schema, "request_id" was renamed to "trace_id", and we
    use the hex-encoded OTel trace ID instead.
    """
    return uuid.uuid4().hex


def construct_full_inputs(func, *args, **kwargs) -> dict[str, Any]:
    """
    Construct the full input arguments dictionary for the given function,
    including positional and keyword arguments.
    """
    signature = inspect.signature(func)
    # This does not create a copy, so values should not be mutated directly.
    arguments = signature.bind_partial(*args, **kwargs).arguments

    if "self" in arguments:
        arguments.pop("self")

    return arguments

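
# Illustrative contrast with capture_function_input_args above: bind_partial is
# used here and defaults are NOT applied. `_f` and the example function are
# hypothetical names, not part of the module API.
def _example_construct_full_inputs() -> None:
    def _f(a, b=2):
        return a + b

    assert construct_full_inputs(_f, 1) == {"a": 1}  # no "b": 2 entry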

@contextmanager
def maybe_set_prediction_context(context: "Context" | None):
    """
    Set the prediction context if the given context
    is not None. Otherwise no-op.
    """
    if not IS_TRACING_SDK_ONLY and context:
        from mlflow.pyfunc.context import set_prediction_context

        with set_prediction_context(context):
            yield
    else:
        yield


def set_span_chat_tools(span: LiveSpan, tools: list[ChatTool]):
    """
    Set the `mlflow.chat.tools` attribute on the specified span. This attribute is
    used in the UI, and also by downstream applications that consume trace data,
    such as MLflow evaluate.

    Args:
        span: The LiveSpan to add the attribute to.
        tools: A list of standardized chat tool definitions (refer to the
            `spec <../llms/tracing/tracing-schema.html#chat-completion-spans>`_
            for details).

    Example:

    .. code-block:: python
        :test:

        import mlflow
        from mlflow.tracing import set_span_chat_tools

        tools = [
            {
                "type": "function",
                "function": {
                    "name": "add",
                    "description": "Add two numbers",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "a": {"type": "number"},
                            "b": {"type": "number"},
                        },
                        "required": ["a", "b"],
                    },
                },
            }
        ]


        @mlflow.trace
        def f():
            span = mlflow.get_current_active_span()
            set_span_chat_tools(span, tools)
            return 0


        f()
    """
    from mlflow.types.chat import ChatTool

    if not isinstance(tools, list):
        raise MlflowTracingException(
            f"Invalid tools type {type(tools)}. Expected a list of ChatTool.",
            error_code=BAD_REQUEST,
        )

    sanitized_tools = []
    for tool in tools:
        if isinstance(tool, dict):
            ChatTool.model_validate(tool)
            sanitized_tools.append(tool)
        elif isinstance(tool, ChatTool):
            sanitized_tools.append(tool.model_dump(exclude_unset=True))

    span.set_attribute(SpanAttributeKey.CHAT_TOOLS, sanitized_tools)
def _calculate_percentile(sorted_data: list[float], percentile: float) -> float:
    """
    Calculate the percentile value from sorted data.

    Args:
        sorted_data: A sorted list of numeric values.
        percentile: The percentile to calculate (e.g., 0.25 for the 25th percentile).

    Returns:
        The percentile value.
    """
    if not sorted_data:
        return 0.0

    n = len(sorted_data)
    index = percentile * (n - 1)
    lower = int(index)
    upper = lower + 1

    if upper >= n:
        return sorted_data[-1]

    # Linear interpolation between the two nearest values
    weight = index - lower
    return sorted_data[lower] * (1 - weight) + sorted_data[upper] * weight


def add_size_stats_to_trace_metadata(trace: Trace):
    """
    Calculate stats about trace and span sizes and add them to the trace metadata.

    This method modifies the trace object in place by adding new metadata entries.

    Note: For simplicity, we calculate the size without considering the size metadata
    itself. This provides a close approximation without requiring complex
    calculations.

    This function must not throw an exception.
    """
    from mlflow.entities import Trace, TraceData

    try:
        span_sizes = []
        for span in trace.data.spans:
            span_json = json.dumps(span.to_dict(), cls=TraceJSONEncoder)
            span_sizes.append(len(span_json.encode("utf-8")))

        # NB: To compute the size of the total trace, we need to include the size of
        # the trace info and the parent dicts for the spans. To avoid serializing
        # spans again (which can be expensive), we compute the size of the trace
        # without spans and combine it with the total size of the spans.
        empty_trace = Trace(info=trace.info, data=TraceData(spans=[]))
        metadata_size = len(empty_trace.to_json().encode("utf-8"))
        # NB: the third term is the size of the comma separators between spans (", ").
        trace_size_bytes = sum(span_sizes) + metadata_size + (len(span_sizes) - 1) * 2

        # Sort span sizes for percentile calculation
        sorted_span_sizes = sorted(span_sizes)

        size_stats = {
            TraceSizeStatsKey.TOTAL_SIZE_BYTES: trace_size_bytes,
            TraceSizeStatsKey.NUM_SPANS: len(span_sizes),
            TraceSizeStatsKey.MAX_SPAN_SIZE_BYTES: max(span_sizes),
            TraceSizeStatsKey.P25_SPAN_SIZE_BYTES: int(
                _calculate_percentile(sorted_span_sizes, 0.25)
            ),
            TraceSizeStatsKey.P50_SPAN_SIZE_BYTES: int(
                _calculate_percentile(sorted_span_sizes, 0.50)
            ),
            TraceSizeStatsKey.P75_SPAN_SIZE_BYTES: int(
                _calculate_percentile(sorted_span_sizes, 0.75)
            ),
        }

        trace.info.trace_metadata[TraceMetadataKey.SIZE_STATS] = json.dumps(size_stats)
        # Keep the total size as separate metadata for backward compatibility
        trace.info.trace_metadata[TraceMetadataKey.SIZE_BYTES] = str(trace_size_bytes)
    except Exception:
        _logger.warning("Failed to add size stats to trace metadata.", exc_info=True)


def update_trace_state_from_span_conditionally(trace, root_span):
    """
    Update the trace state from the span status, but only if the user hasn't
    explicitly set a different trace status.

    This utility preserves a user-set trace status while maintaining the default
    behavior for traces that haven't been explicitly configured. Used by trace
    processors when converting traces to an exportable state.

    Args:
        trace: The trace object to potentially update.
        root_span: The root span whose status may be used to update the trace state.
    """
    from mlflow.entities.trace_state import TraceState

    # Only update the trace state from the span status if the trace is still
    # IN_PROGRESS. Any other state means the user explicitly set it, and we
    # should preserve it.
    if trace.info.state == TraceState.IN_PROGRESS:
        state = TraceState.from_otel_status(root_span.status)
        # If the root span was created by the native OpenTelemetry SDK, the status
        # code can be UNSET (the default when an OTel span is ended). Override it
        # to OK here to avoid a backend error.
        if state == TraceState.STATE_UNSPECIFIED:
            state = TraceState.OK
        trace.info.state = state


def get_experiment_id_for_trace(span: OTelReadableSpan) -> str:
    """
    Determine the experiment ID to associate with the trace.

    The experiment ID can be configured in multiple ways, in order of precedence:

    1. An experiment ID specified via the span creation API,
       i.e. MlflowClient().start_trace()
    2. An experiment ID specified via `mlflow.tracing.set_destination`
    3. The experiment ID of an active run
    4. The default experiment ID

    Args:
        span: The OpenTelemetry ReadableSpan to extract the experiment ID from.

    Returns:
        The experiment ID string to use for the trace.
    """
    from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION
    from mlflow.tracking.fluent import _get_experiment_id, _get_latest_active_run

    if experiment_id := get_otel_attribute(span, SpanAttributeKey.EXPERIMENT_ID):
        return experiment_id

    if destination := _MLFLOW_TRACE_USER_DESTINATION.get():
        if exp_id := getattr(destination, "experiment_id", None):
            return exp_id

    if run := _get_latest_active_run():
        return run.info.experiment_id

    return _get_experiment_id()


def get_active_spans_table_name() -> str | None:
    """
    Get the active Unity Catalog spans table name that's set by
    `mlflow.tracing.set_destination`.
    """
    from mlflow.entities.trace_location import UCSchemaLocation
    from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION

    if destination := _MLFLOW_TRACE_USER_DESTINATION.get():
        if isinstance(destination, UCSchemaLocation):
            return destination.full_otel_spans_table_name
    return None


def generate_assessment_id() -> str:
    """
    Generate an assessment ID of the form 'a-<uuid4>' in hex string format.

    Returns:
        A unique identifier for an assessment that will be logged to a trace tag.
    """
    id = uuid.uuid4().hex
    return f"{ASSESSMENT_ID_PREFIX}{id}"


@contextmanager
def _bypass_attribute_guard(span: OTelSpan) -> Generator[None, None, None]:
    """
    OpenTelemetry does not allow setting attributes if the span has an end time
    defined.
    https://github.com/open-telemetry/opentelemetry-python/blob/d327927d0274a320466feec6fba6d6ddb287dc5a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py#L849-L851

    However, we need to set some attributes within the `on_end` handler of the span
    processor, where the span is already marked as ended. This context manager is a
    hacky workaround to bypass the attribute guard.
    """
    original_end_time = span._end_time
    span._end_time = None
    try:
        yield
    finally:
        span._end_time = original_end_time


def parse_trace_id_v4(trace_id: str | None) -> tuple[str | None, str | None]:
    """
    Parse the trace ID into location and trace ID components.
    """
    if trace_id is None:
        return None, None

    if trace_id.startswith(TRACE_ID_V4_PREFIX):
        match trace_id.removeprefix(TRACE_ID_V4_PREFIX).split("/"):
            case [location, tid] if location and tid:
                return location, tid
            case _:
                raise MlflowException.invalid_parameter_value(
                    f"Invalid trace ID format: {trace_id}. "
                    f"Expected format: {TRACE_ID_V4_PREFIX}<location>/<trace_id>"
                )
    return None, trace_id


def construct_trace_id_v4(location: str, trace_id: str) -> str:
    """
    Construct a trace ID for the given location and trace ID.
    """
    return f"{TRACE_ID_V4_PREFIX}{location}/{trace_id}"


def set_span_model_attribute(span: LiveSpan, inputs: dict[str, Any]) -> None:
    """
    Set the model attribute on a span using parsed model information.

    This utility function extracts the model name from the inputs and sets it as a
    span attribute. It's used by autologging implementations to consistently set
    model information across different LLM providers.

    Args:
        span: The LiveSpan to set the model attribute on.
        inputs: The request inputs dictionary.
    """
    try:
        if (model := inputs.get("model")) and isinstance(model, str):
            span.set_attribute(SpanAttributeKey.MODEL, model)
    except Exception as e:
        _logger.debug(f"Failed to set model for {span}. Error: {e}")


def set_span_cost_attribute(span: LiveSpan) -> None:
    """
    Set the cost attribute on a span using calculated cost information.
    """
    try:
        if cost := calculate_span_cost(span):
            span.set_attribute(SpanAttributeKey.LLM_COST, cost)
    except Exception as e:
        _logger.debug(f"Failed to set cost for {span}. Error: {e}")