Source code for mlflow.entities.trace_data

from collections import Counter
from dataclasses import dataclass, field
from typing import Any

from mlflow.entities import Span
from mlflow.tracing.constant import SpanAttributeKey
from mlflow.utils.annotations import deprecated


@dataclass
class TraceData:
    """Container for the spans that make up a single trace.

    Args:
        spans: List of spans that are part of the trace.
    """

    spans: list[Span] = field(default_factory=list)

    # NB: A hand-written constructor is kept so that extra keyword arguments are
    # accepted (and ignored) for backward compatibility with the DBX agent
    # evaluator. Once it migrates to the trace V3 schema, this can be removed.
    def __init__(self, spans: list[Span] | None = None, **kwargs):
        # A missing or empty `spans` argument yields a fresh empty list.
        self.spans = spans if spans else []

    @classmethod
    def from_dict(cls, d):
        """Build a ``TraceData`` from a dictionary.

        Args:
            d: Dictionary whose optional ``"spans"`` entry is an iterable of
                items that are each passed to ``Span.from_dict``.

        Returns:
            A new instance holding the converted spans.

        Raises:
            TypeError: If ``d`` is not a dictionary.
        """
        if isinstance(d, dict):
            raw_spans = d.get("spans", [])
            return cls(spans=[Span.from_dict(raw) for raw in raw_spans])
        raise TypeError(f"TraceData.from_dict() expects a dictionary. Got: {type(d).__name__}")

    def to_dict(self) -> dict[str, Any]:
        """Return a dictionary with every span converted via ``Span.to_dict``."""
        serialized = [span.to_dict() for span in self.spans]
        return {"spans": serialized}

    # TODO: remove this property in 3.7.0
    @property
    @deprecated(since="3.6.0", alternative="trace.search_spans(name=...)")
    def intermediate_outputs(self) -> dict[str, Any] | None:
        """
        .. deprecated:: 3.6.0
            Use `trace.search_spans(name=...)` to search for spans and get the outputs.

        Returns intermediate outputs produced by the model or agent while handling
        the request. There are mainly two flows to return intermediate outputs:

        1. When a trace is generated by the `mlflow.log_trace` API, return the
           `intermediate_outputs` attribute of the root span.
        2. When a trace is created normally with a tree of spans, aggregate the
           outputs of non-root spans, keyed by span name.

        Returns ``None`` when neither flow applies, i.e. the root span carries no
        such attribute and the trace holds at most one span.
        """
        root = self._get_root_span()
        if root:
            recorded = root.get_attribute(SpanAttributeKey.INTERMEDIATE_OUTPUTS)
            if recorded:
                return recorded

        if len(self.spans) <= 1:
            return None

        # Span names may repeat; names that occur more than once get a running
        # 1-based suffix so their outputs do not overwrite each other.
        occurrences = Counter(span.name for span in self.spans)
        next_suffix = dict.fromkeys(
            (name for name, total in occurrences.items() if total > 1), 1
        )

        collected = {}
        for span in self.spans:
            key = span.name
            if key in next_suffix:
                # The suffix advances for every span with a repeated name, even
                # one that ends up contributing no entry below.
                suffix = next_suffix[key]
                next_suffix[key] = suffix + 1
                key = f"{key}_{suffix}"

            if not span.parent_id:
                continue
            outputs = span.outputs
            if outputs is not None:
                collected[key] = outputs
        return collected

    def _get_root_span(self) -> Span | None:
        """Return the first span whose ``parent_id`` is ``None``, or ``None`` if there is none."""
        return next((span for span in self.spans if span.parent_id is None), None)

    # `request` and `response` are preserved for backward compatibility with v2
    @property
    def request(self) -> str | None:
        """Inputs attribute of the root span, or ``None`` when there is no root span."""
        root = self._get_root_span()
        if not root:
            return None
        # Read from the underlying OTel span's attributes directly, which is
        # expected to give the value in its serialized form.
        return root._span.attributes.get(SpanAttributeKey.INPUTS)

    @property
    def response(self) -> str | None:
        """Outputs attribute of the root span, or ``None`` when there is no root span."""
        root = self._get_root_span()
        if not root:
            return None
        # Read from the underlying OTel span's attributes directly, which is
        # expected to give the value in its serialized form.
        return root._span.attributes.get(SpanAttributeKey.OUTPUTS)