# TODO: Split this file into multiple files and move under utils directory.
from __future__ import annotations
import inspect
import json
import logging
import uuid
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import asdict, is_dataclass
from typing import TYPE_CHECKING, Any, Generator
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import ReadableSpan as OTelReadableSpan
from opentelemetry.sdk.trace import Span as OTelSpan
from mlflow.exceptions import BAD_REQUEST, MlflowException, MlflowTracingException
from mlflow.tracing.constant import (
ASSESSMENT_ID_PREFIX,
TRACE_ID_V4_PREFIX,
TRACE_REQUEST_ID_PREFIX,
SpanAttributeKey,
TokenUsageKey,
TraceMetadataKey,
TraceSizeStatsKey,
)
from mlflow.tracing.constant import (
CostKey as CostKey,
)
from mlflow.utils.mlflow_tags import IMMUTABLE_TAGS
from mlflow.version import IS_TRACING_SDK_ONLY
_logger = logging.getLogger(__name__)
SPANS_COLUMN_NAME = "spans"
if TYPE_CHECKING:
from mlflow.entities import LiveSpan, Trace
from mlflow.pyfunc.context import Context
from mlflow.types.chat import ChatTool
def capture_function_input_args(func, args, kwargs) -> dict[str, Any] | None:
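    """
    Bind ``args`` and ``kwargs`` to ``func``'s signature and return the bound
    arguments as a dict, applying defaults and dropping ``self``/``cls``.
    Returns None if the arguments cannot be bound.

    Illustrative example (hypothetical function, shown for clarity):

        def greet(name, punctuation="!"): ...

        capture_function_input_args(greet, ("world",), {})
        # -> {"name": "world", "punctuation": "!"}
    """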
try:
func_signature = inspect.signature(func)
bound_arguments = func_signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
        # Remove `self` from bound arguments if it exists. Check membership rather
        # than truthiness so a falsy `self` is still removed.
        bound_arguments.arguments.pop("self", None)
# Remove `cls` from bound arguments if it's the first parameter and it's a type
# This detects classmethods more reliably
params = list(bound_arguments.arguments.keys())
if params and params[0] == "cls" and isinstance(bound_arguments.arguments["cls"], type):
del bound_arguments.arguments["cls"]
return bound_arguments.arguments
except Exception:
_logger.warning(f"Failed to capture inputs for function {func.__name__}.")
return None
class TraceJSONEncoder(json.JSONEncoder):
"""
Custom JSON encoder for serializing non-OpenTelemetry compatible objects in a trace or span.
    A trace may contain types that require custom serialization logic, such as
    Pydantic models and other non-JSON-serializable types.
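
    Usage (illustrative):

        json.dumps(span_inputs, cls=TraceJSONEncoder, ensure_ascii=False)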
"""
def default(self, obj):
try:
import pydantic
if isinstance(obj, pydantic.BaseModel):
return obj.model_dump()
except ImportError:
pass
        # Some dataclass objects define a __str__ method that doesn't return the full
        # object representation, so we use the dict representation instead.
# E.g. https://github.com/run-llama/llama_index/blob/29ece9b058f6b9a1cf29bc723ed4aa3a39879ad5/llama-index-core/llama_index/core/chat_engine/types.py#L63-L64
if is_dataclass(obj):
try:
return asdict(obj)
except TypeError:
pass
        # Some objects have dangerous side effects in their __str__ method, so we fall
        # back to the class name instead. Returning the string directly avoids another
        # round of `default()` on a raw type object.
        if not self._is_safe_to_encode_str(obj):
            return str(type(obj))
try:
return super().default(obj)
except TypeError:
return str(obj)
def _is_safe_to_encode_str(self, obj) -> bool:
"""Check if it's safe to encode the object as a string."""
try:
            # These Llama Index objects are not safe to encode as strings, because
            # their __str__ method consumes the stream and makes it unusable.
# E.g. https://github.com/run-llama/llama_index/blob/54f2da61ba8a573284ab8336f2b2810d948c3877/llama-index-core/llama_index/core/base/response/schema.py#L120-L127
from llama_index.core.base.response.schema import (
AsyncStreamingResponse,
StreamingResponse,
)
from llama_index.core.chat_engine.types import StreamingAgentChatResponse
if isinstance(
obj,
(AsyncStreamingResponse, StreamingResponse, StreamingAgentChatResponse),
):
return False
except ImportError:
pass
return True
def dump_span_attribute_value(value: Any) -> str:
    # NB: OpenTelemetry attributes can store not only strings but also a few primitives
    # like int, float, bool, and lists of them. However, we serialize everything into a
    # JSON string here for simplicity in the deserialization process.
return json.dumps(value, cls=TraceJSONEncoder, ensure_ascii=False)
def encode_span_id(span_id: int) -> str:
    """
    Encode the given integer span ID to a 16-character hex string.

    See https://github.com/open-telemetry/opentelemetry-python/blob/9398f26ecad09e02ad044859334cd4c75299c3cd/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py#L507-L508

    NB: We don't add the '0x' prefix to the hex string here for simpler parsing in the
    backend. Some backends (e.g. Databricks) disallow this prefix.
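
    Example:

        encode_span_id(5)             # -> "0000000000000005"
        decode_id(encode_span_id(5))  # -> 5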
"""
return trace_api.format_span_id(span_id)
def encode_trace_id(trace_id: int) -> str:
    """
    Encode the given integer trace ID to a 32-character hex string.
"""
return trace_api.format_trace_id(trace_id)
def decode_id(span_or_trace_id: str) -> int:
"""
Decode the given hex string span or trace ID to an integer.
"""
return int(span_or_trace_id, 16)
def get_mlflow_span_for_otel_span(span: OTelSpan) -> LiveSpan | None:
"""
Get the active MLflow span for the given OpenTelemetry span.
"""
from mlflow.tracing.trace_manager import InMemoryTraceManager
trace_id = get_otel_attribute(span, SpanAttributeKey.REQUEST_ID)
mlflow_span_id = encode_span_id(span.get_span_context().span_id)
return InMemoryTraceManager.get_instance().get_span_from_id(trace_id, mlflow_span_id)
def build_otel_context(trace_id: int, span_id: int) -> trace_api.SpanContext:
"""
Build an OpenTelemetry SpanContext object from the given trace and span IDs.
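
    Illustrative usage: the returned context can be wrapped in a ``NonRecordingSpan``
    to act as the parent when starting a child span:

        parent = trace_api.NonRecordingSpan(build_otel_context(trace_id, span_id))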
"""
return trace_api.SpanContext(
trace_id=trace_id,
span_id=span_id,
        # NB: This flag is OpenTelemetry's concept to indicate whether the context is
        # propagated from a remote parent or not. We don't support distributed tracing
        # yet, so it is always set to False.
is_remote=False,
)
def _aggregate_from_spans(
spans: list[LiveSpan],
attribute_key: str,
input_key: str,
output_key: str,
total_key: str,
default: int | float,
) -> dict[str, int | float] | None:
"""Generic aggregation of data from spans using DFS traversal.
Avoids double-counting by skipping spans whose ancestors already have the data.
Args:
spans: List of spans to aggregate from.
attribute_key: The span attribute key to look up.
input_key: Key for extracting input value from span data.
output_key: Key for extracting output value from span data.
total_key: Key for extracting total value from span data.
default: Default value (0 for int, 0.0 for float) that also determines return type.
Returns:
        Aggregated dictionary with the given keys, or None if no data is found.
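
    Example (illustrative, using token-usage keys): if a root span carries
    ``{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}`` and its child
    span carries the same attribute (propagated upward), only the root's values are
    counted, yielding 10/5/15 rather than a doubled 20/10/30.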
"""
input_val = default
output_val = default
total_val = default
has_data = False
span_id_to_spans = {span.span_id: span for span in spans}
children_map: defaultdict[str, list[LiveSpan]] = defaultdict(list)
roots: list[LiveSpan] = []
for span in spans:
parent_id = span.parent_id
if parent_id and parent_id in span_id_to_spans:
children_map[parent_id].append(span)
else:
roots.append(span)
def dfs(span: LiveSpan, ancestor_has_data: bool) -> None:
nonlocal input_val, output_val, total_val, has_data
data = span.get_attribute(attribute_key)
span_has_data = data is not None
if span_has_data and not ancestor_has_data:
input_val += data.get(input_key, default)
output_val += data.get(output_key, default)
total_val += data.get(total_key, default)
has_data = True
next_ancestor_has_data = ancestor_has_data or span_has_data
for child in children_map.get(span.span_id, []):
dfs(child, next_ancestor_has_data)
for root in roots:
dfs(root, False)
if not has_data:
return None
return {
input_key: input_val,
output_key: output_val,
total_key: total_val,
}
def aggregate_usage_from_spans(spans: list[LiveSpan]) -> dict[str, int] | None:
"""Aggregate token usage information from all spans in the trace."""
return _aggregate_from_spans(
spans,
SpanAttributeKey.CHAT_USAGE,
TokenUsageKey.INPUT_TOKENS,
TokenUsageKey.OUTPUT_TOKENS,
TokenUsageKey.TOTAL_TOKENS,
0,
)
def aggregate_cost_from_spans(spans: list[LiveSpan]) -> dict[str, float] | None:
"""Aggregate cost information from all spans in the trace."""
return _aggregate_from_spans(
spans,
SpanAttributeKey.LLM_COST,
CostKey.INPUT_COST,
CostKey.OUTPUT_COST,
CostKey.TOTAL_COST,
0.0,
)
def calculate_span_cost(span: LiveSpan) -> dict[str, float] | None:
"""Calculate cost for a single span using LiteLLM pricing data.
Args:
span: The span to calculate cost for.
Returns:
Dictionary with input_cost, output_cost, and total_cost in USD,
or None if cost cannot be calculated.
"""
model_name = span.get_attribute(SpanAttributeKey.MODEL)
usage = span.get_attribute(SpanAttributeKey.CHAT_USAGE)
model_provider = span.get_attribute(SpanAttributeKey.MODEL_PROVIDER)
return calculate_cost_by_model_and_token_usage(model_name, usage, model_provider)
def calculate_cost_by_model_and_token_usage(
model_name: str | None, usage: dict[str, int] | None, model_provider: str | None = None
) -> dict[str, float] | None:
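    """
    Compute USD costs for the given token usage using LiteLLM's per-token pricing.

    Returns None if the model name or usage is missing, the usage is all zeros, or
    LiteLLM is unavailable or has no pricing for the model.

    Illustrative example (actual values depend on LiteLLM's pricing table):

        calculate_cost_by_model_and_token_usage(
            "gpt-4o", {TokenUsageKey.INPUT_TOKENS: 1000, TokenUsageKey.OUTPUT_TOKENS: 200}
        )
        # -> {CostKey.INPUT_COST: ..., CostKey.OUTPUT_COST: ..., CostKey.TOTAL_COST: ...}
    """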
if not model_name or not usage:
return None
try:
from litellm import cost_per_token
except ImportError:
_logger.debug("LiteLLM not available for cost calculation")
return None
prompt_tokens = usage.get(TokenUsageKey.INPUT_TOKENS, 0)
completion_tokens = usage.get(TokenUsageKey.OUTPUT_TOKENS, 0)
if prompt_tokens == 0 and completion_tokens == 0:
return None
try:
input_cost_usd, output_cost_usd = cost_per_token(
model=model_name, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
)
except Exception as e:
if model_provider:
            # Pass model_provider only in the exception case: when model_name alone is
            # enough to calculate the cost, we avoid passing it, since the
            # model_provider field can contain arbitrary values that LiteLLM may not
            # support.
try:
input_cost_usd, output_cost_usd = cost_per_token(
model=model_name,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
custom_llm_provider=model_provider,
)
except Exception as e:
_logger.debug(
f"Failed to calculate cost for model {model_name}: {e}", exc_info=True
)
return None
else:
_logger.debug(
f"Failed to calculate cost for model {model_name} without provider: {e}",
exc_info=True,
)
return None
return {
CostKey.INPUT_COST: input_cost_usd,
CostKey.OUTPUT_COST: output_cost_usd,
CostKey.TOTAL_COST: input_cost_usd + output_cost_usd,
}
def get_otel_attribute(span: trace_api.Span, key: str) -> str | None:
"""
Get the attribute value from the OpenTelemetry span in a decoded format.
Args:
span: The OpenTelemetry span object.
key: The key of the attribute to retrieve.
Returns:
        The decoded attribute value. If the attribute is not found or cannot be
        parsed, returns None.
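
    Example (illustrative; assumes the attribute was stored as a JSON string via
    ``dump_span_attribute_value``):

        span.set_attribute(key, dump_span_attribute_value({"x": 1}))
        get_otel_attribute(span, key)  # -> {"x": 1}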
"""
try:
attribute_value = span.attributes.get(key)
if attribute_value is None:
return None
return json.loads(attribute_value)
except Exception:
_logger.debug(f"Failed to get attribute {key} with from span {span}.", exc_info=True)
def _try_get_prediction_context():
# NB: Tracing is enabled in mlflow-skinny, but the pyfunc module cannot be imported as it
# relies on numpy, which is not installed in skinny.
try:
from mlflow.pyfunc.context import get_prediction_context
    except ImportError:
return
return get_prediction_context()
def maybe_get_request_id(is_evaluate=False) -> str | None:
"""Get the request ID if the current prediction is as a part of MLflow model evaluation."""
context = _try_get_prediction_context()
if not context or (is_evaluate and not context.is_evaluate):
return None
if not context.request_id and is_evaluate:
_logger.warning(
f"Missing request_id for context {context}. request_id can't be None when "
"is_evaluate=True. This is likely an internal error of MLflow, please file "
"a bug report at https://github.com/mlflow/mlflow/issues."
)
return None
return context.request_id
def maybe_get_dependencies_schemas() -> dict[str, Any] | None:
if context := _try_get_prediction_context():
return context.dependencies_schemas
def maybe_get_logged_model_id() -> str | None:
"""
Get the logged model ID associated with the current prediction context.
"""
if context := _try_get_prediction_context():
return context.model_id
def exclude_immutable_tags(tags: dict[str, str]) -> dict[str, str]:
"""Exclude immutable tags e.g. "mlflow.user" from the given tags."""
return {k: v for k, v in tags.items() if k not in IMMUTABLE_TAGS}
def generate_mlflow_trace_id_from_otel_trace_id(otel_trace_id: int) -> str:
"""
Generate an MLflow trace ID from an OpenTelemetry trace ID.
Args:
otel_trace_id: The OpenTelemetry trace ID as an integer.
Returns:
The MLflow trace ID string in format "tr-<hex_trace_id>".
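
    Example:

        generate_mlflow_trace_id_from_otel_trace_id(0x3039)
        # -> "tr-00000000000000000000000000003039"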
"""
return TRACE_REQUEST_ID_PREFIX + encode_trace_id(otel_trace_id)
def generate_trace_id_v4_from_otel_trace_id(otel_trace_id: int, location: str) -> str:
"""
Generate a trace ID in v4 format from the given OpenTelemetry trace ID.
Args:
otel_trace_id: The OpenTelemetry trace ID as an integer.
        location: The location of the trace.
Returns:
The MLflow trace ID string in format "trace:/<location>/<hex_trace_id>".
"""
return construct_trace_id_v4(location, encode_trace_id(otel_trace_id))
def generate_trace_id_v4(span: OTelSpan, location: str) -> str:
"""
Generate a trace ID for the given span.
Args:
span: The OpenTelemetry span object.
        location: The location of the trace.
Returns:
Trace ID with format "trace:/<location>/<hex_trace_id>".
"""
return generate_trace_id_v4_from_otel_trace_id(span.context.trace_id, location)
def generate_trace_id_v3(span: OTelSpan) -> str:
"""
    Generate a trace ID for the given span (V3 trace schema).

    The format is "tr-<trace_id>", where <trace_id> is the hex-encoded OTel trace ID.
"""
return generate_mlflow_trace_id_from_otel_trace_id(span.context.trace_id)
def generate_request_id_v2() -> str:
"""
Generate a request ID for the given span.
This should only be used for V2 trace schema where we use a random UUID as
request ID. In the V3 schema, "request_id" is renamed to "trace_id" and
we use the otel-generated trace ID with encoding.
"""
return uuid.uuid4().hex
def construct_full_inputs(func, *args, **kwargs) -> dict[str, Any]:
"""
Construct the full input arguments dictionary for the given function,
including positional and keyword arguments.
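
    Illustrative example (hypothetical method):

        class Model:
            def predict(self, messages, temperature=0.5): ...

        construct_full_inputs(Model.predict, Model(), ["hi"])
        # -> {"messages": ["hi"]}  (defaults are not applied; `self` is dropped)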
"""
signature = inspect.signature(func)
    # NB: This does not create a copy, so values should not be mutated directly.
arguments = signature.bind_partial(*args, **kwargs).arguments
if "self" in arguments:
arguments.pop("self")
return arguments
@contextmanager
def maybe_set_prediction_context(context: "Context" | None):
"""
    Set the prediction context if the given context is not None; otherwise this is
    a no-op.
"""
if not IS_TRACING_SDK_ONLY and context:
from mlflow.pyfunc.context import set_prediction_context
with set_prediction_context(context):
yield
else:
yield
def _calculate_percentile(sorted_data: list[float], percentile: float) -> float:
"""
Calculate the percentile value from sorted data.
Args:
sorted_data: A sorted list of numeric values
percentile: The percentile to calculate (e.g., 0.25 for 25th percentile)
Returns:
The percentile value
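
    Example: ``_calculate_percentile([1.0, 2.0, 3.0, 4.0], 0.5)`` returns 2.5, since
    index 1.5 interpolates halfway between 2.0 and 3.0.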
"""
if not sorted_data:
return 0.0
n = len(sorted_data)
index = percentile * (n - 1)
lower = int(index)
upper = lower + 1
if upper >= n:
return sorted_data[-1]
# Linear interpolation between two nearest values
weight = index - lower
return sorted_data[lower] * (1 - weight) + sorted_data[upper] * weight
def add_size_stats_to_trace_metadata(trace: Trace):
"""
    Calculate statistics of the trace and span sizes and add them to the trace metadata.
    This function modifies the trace object in place by adding new metadata entries.
Note: For simplicity, we calculate the size without considering the size metadata itself.
This provides a close approximation without requiring complex calculations.
This function must not throw an exception.
"""
from mlflow.entities import Trace, TraceData
try:
span_sizes = []
for span in trace.data.spans:
span_json = json.dumps(span.to_dict(), cls=TraceJSONEncoder)
span_sizes.append(len(span_json.encode("utf-8")))
        if not span_sizes:
            return
        # NB: To compute the size of the total trace, we need to include the size of
        # the trace info and the parent dicts for the spans. To avoid serializing the
        # spans again (which can be expensive), we compute the size of the trace
        # without spans and combine it with the total size of the spans.
empty_trace = Trace(info=trace.info, data=TraceData(spans=[]))
metadata_size = len((empty_trace.to_json()).encode("utf-8"))
# NB: the third term is the size of comma separators between spans (", ").
trace_size_bytes = sum(span_sizes) + metadata_size + (len(span_sizes) - 1) * 2
# Sort span sizes for percentile calculation
sorted_span_sizes = sorted(span_sizes)
size_stats = {
TraceSizeStatsKey.TOTAL_SIZE_BYTES: trace_size_bytes,
TraceSizeStatsKey.NUM_SPANS: len(span_sizes),
TraceSizeStatsKey.MAX_SPAN_SIZE_BYTES: max(span_sizes),
TraceSizeStatsKey.P25_SPAN_SIZE_BYTES: int(
_calculate_percentile(sorted_span_sizes, 0.25)
),
TraceSizeStatsKey.P50_SPAN_SIZE_BYTES: int(
_calculate_percentile(sorted_span_sizes, 0.50)
),
TraceSizeStatsKey.P75_SPAN_SIZE_BYTES: int(
_calculate_percentile(sorted_span_sizes, 0.75)
),
}
trace.info.trace_metadata[TraceMetadataKey.SIZE_STATS] = json.dumps(size_stats)
# Keep the total size as a separate metadata for backward compatibility
trace.info.trace_metadata[TraceMetadataKey.SIZE_BYTES] = str(trace_size_bytes)
except Exception:
_logger.warning("Failed to add size stats to trace metadata.", exc_info=True)
def update_trace_state_from_span_conditionally(trace, root_span):
"""
Update trace state from span status, but only if the user hasn't explicitly set
a different trace status.
This utility preserves user-set trace status while maintaining default behavior
for traces that haven't been explicitly configured. Used by trace processors when
converting traces to an exportable state.
Args:
trace: The trace object to potentially update
root_span: The root span whose status may be used to update the trace state
"""
from mlflow.entities.trace_state import TraceState
# Only update trace state from span status if trace is still IN_PROGRESS
# If the trace state is anything else, it means the user explicitly set it
# and we should preserve it
if trace.info.state == TraceState.IN_PROGRESS:
state = TraceState.from_otel_status(root_span.status)
        # If the root span was created by the native OpenTelemetry SDK, the status code
        # can be UNSET (the default when an OTel span is ended). Override it to OK here
        # to avoid a backend error.
if state == TraceState.STATE_UNSPECIFIED:
state = TraceState.OK
trace.info.state = state
def get_experiment_id_for_trace(span: OTelReadableSpan) -> str:
"""
Determine the experiment ID to associate with the trace.
The experiment ID can be configured in multiple ways, in order of precedence:
1. An experiment ID specified via the span creation API i.e. MlflowClient().start_trace()
2. An experiment ID specified via `mlflow.tracing.set_destination`
3. An experiment ID of an active run.
4. The default experiment ID
Args:
span: The OpenTelemetry ReadableSpan to extract experiment ID from.
Returns:
The experiment ID string to use for the trace.
"""
from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION
from mlflow.tracking.fluent import _get_experiment_id, _get_latest_active_run
if experiment_id := get_otel_attribute(span, SpanAttributeKey.EXPERIMENT_ID):
return experiment_id
if destination := _MLFLOW_TRACE_USER_DESTINATION.get():
if exp_id := getattr(destination, "experiment_id", None):
return exp_id
if run := _get_latest_active_run():
return run.info.experiment_id
return _get_experiment_id()
def get_active_spans_table_name() -> str | None:
"""
    Get the active Unity Catalog spans table name set by `mlflow.tracing.set_destination`.
"""
from mlflow.entities.trace_location import UCSchemaLocation
from mlflow.tracing.provider import _MLFLOW_TRACE_USER_DESTINATION
if destination := _MLFLOW_TRACE_USER_DESTINATION.get():
if isinstance(destination, UCSchemaLocation):
return destination.full_otel_spans_table_name
return None
def generate_assessment_id() -> str:
"""
Generates an assessment ID of the form 'a-<uuid4>' in hex string format.
Returns:
A unique identifier for an assessment that will be logged to a trace tag.
"""
    assessment_id = uuid.uuid4().hex
    return f"{ASSESSMENT_ID_PREFIX}{assessment_id}"
@contextmanager
def _bypass_attribute_guard(span: OTelSpan) -> Generator[None, None, None]:
"""
OpenTelemetry does not allow setting attributes if the span has end time defined.
https://github.com/open-telemetry/opentelemetry-python/blob/d327927d0274a320466feec6fba6d6ddb287dc5a/opentelemetry-sdk/src/opentelemetry/sdk/trace/__init__.py#L849-L851
However, we need to set some attributes within `on_end` handler of the span processor,
where the span is already marked as ended. This context manager is a hacky workaround
to bypass the attribute guard.
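
    Illustrative usage:

        with _bypass_attribute_guard(otel_span):
            otel_span.set_attribute("custom.key", "value")  # allowed even after end()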
"""
original_end_time = span._end_time
span._end_time = None
try:
yield
finally:
span._end_time = original_end_time
def parse_trace_id_v4(trace_id: str | None) -> tuple[str | None, str | None]:
"""
    Parse a trace ID into its location and trace ID components.

    Returns a (location, trace_id) tuple. For inputs not in the v4 format, the
    location is None and the input is returned unchanged as the trace ID.
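
    Illustrative example (v4 prefix is "trace:/"):

        parse_trace_id_v4("trace:/catalog.schema/abc123")  # -> ("catalog.schema", "abc123")
        parse_trace_id_v4("tr-abc123")                     # -> (None, "tr-abc123")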
"""
if trace_id is None:
return None, None
if trace_id.startswith(TRACE_ID_V4_PREFIX):
match trace_id.removeprefix(TRACE_ID_V4_PREFIX).split("/"):
case [location, tid] if location and tid:
return location, tid
case _:
raise MlflowException.invalid_parameter_value(
f"Invalid trace ID format: {trace_id}. "
f"Expected format: {TRACE_ID_V4_PREFIX}<location>/<trace_id>"
)
return None, trace_id
def construct_trace_id_v4(location: str, trace_id: str) -> str:
"""
Construct a trace ID for the given location and trace ID.
"""
return f"{TRACE_ID_V4_PREFIX}{location}/{trace_id}"
def set_span_model_attribute(span: LiveSpan, inputs: dict[str, Any]) -> None:
"""
Set the model attribute on a span using parsed model information.
This utility function extracts the model name from inputs and
sets it as a span attribute. It's used by autologging implementations to
consistently set model information across different LLM providers.
Args:
span: The LiveSpan to set the model attribute on
inputs: The request inputs dictionary
"""
try:
if (model := inputs.get("model")) and isinstance(model, str):
span.set_attribute(SpanAttributeKey.MODEL, model)
except Exception as e:
_logger.debug(f"Failed to set model for {span}. Error: {e}")
def set_span_cost_attribute(span: LiveSpan) -> None:
"""
Set the cost attribute on a span using calculated cost information.
"""
try:
if cost := calculate_span_cost(span):
span.set_attribute(SpanAttributeKey.LLM_COST, cost)
except Exception as e:
_logger.debug(f"Failed to set cost for {span}. Error: {e}")