# Source code for mlflow.entities.trace_info

import json
from dataclasses import dataclass, field
from typing import Any

from google.protobuf.duration_pb2 import Duration
from google.protobuf.json_format import MessageToDict
from google.protobuf.timestamp_pb2 import Timestamp

from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.assessment import Assessment
from mlflow.entities.trace_location import TraceLocation
from mlflow.entities.trace_state import TraceState
from mlflow.entities.trace_status import TraceStatus
from mlflow.protos.databricks_tracing_pb2 import TraceInfo as ProtoTraceInfoV4
from mlflow.protos.service_pb2 import TraceInfoV3 as ProtoTraceInfoV3
from mlflow.tracing.constant import TraceMetadataKey


@dataclass
class TraceInfo(_MlflowObject):
    """Metadata about a trace, such as its ID, location, timestamp, etc.

    Args:
        trace_id: The primary identifier for the trace.
        trace_location: The location where the trace is stored, represented as a
            :py:class:`~mlflow.entities.TraceLocation` object. MLflow currently
            support MLflow Experiment or Databricks Inference Table as a trace
            location.
        request_time: Start time of the trace, in milliseconds.
        state: State of the trace, represented as a
            :py:class:`~mlflow.entities.TraceState` enum. Can be one of
            [`OK`, `ERROR`, `IN_PROGRESS`, `STATE_UNSPECIFIED`].
        request_preview: Request to the model/agent, equivalent to the input of the
            root span, but JSON-encoded and can be truncated.
        response_preview: Response from the model/agent, equivalent to the output of
            the root span, but JSON-encoded and can be truncated.
        client_request_id: Client supplied request ID associated with the trace. This
            could be used to identify the trace/request from an external system that
            produced the trace, e.g., a session ID in a web application.
        execution_duration: Duration of the trace, in milliseconds.
        trace_metadata: Key-value pairs associated with the trace. They are designed
            for immutable values like run ID associated with the trace.
        tags: Tags associated with the trace. They are designed for mutable values,
            that can be updated after the trace is created via MLflow UI or API.
        assessments: List of assessments associated with the trace.
    """

    trace_id: str
    trace_location: TraceLocation
    request_time: int
    state: TraceState
    request_preview: str | None = None
    response_preview: str | None = None
    client_request_id: str | None = None
    execution_duration: int | None = None
    trace_metadata: dict[str, str] = field(default_factory=dict)
    tags: dict[str, str] = field(default_factory=dict)
    assessments: list[Assessment] = field(default_factory=list)
[docs] def to_dict(self) -> dict[str, Any]: """Convert the TraceInfoV3 object to a dictionary.""" res = MessageToDict(self.to_proto(), preserving_proto_field_name=True) if self.execution_duration is not None: res.pop("execution_duration", None) res["execution_duration_ms"] = self.execution_duration # override trace_id to be the same as trace_info.trace_id since it's parsed # when converting to proto if it's v4 res["trace_id"] = self.trace_id return res
[docs] @classmethod def from_dict(cls, d: dict[str, Any]) -> "TraceInfo": """Create a TraceInfoV3 object from a dictionary.""" if "request_id" in d: from mlflow.entities.trace_info_v2 import TraceInfoV2 return TraceInfoV2.from_dict(d).to_v3() d = d.copy() if assessments := d.get("assessments"): d["assessments"] = [Assessment.from_dictionary(a) for a in assessments] if trace_location := d.get("trace_location"): d["trace_location"] = TraceLocation.from_dict(trace_location) if state := d.get("state"): d["state"] = TraceState(state) if request_time := d.get("request_time"): timestamp = Timestamp() timestamp.FromJsonString(request_time) d["request_time"] = timestamp.ToMilliseconds() if (execution_duration := d.pop("execution_duration_ms", None)) is not None: d["execution_duration"] = execution_duration return cls(**d)
[docs] def to_proto(self) -> ProtoTraceInfoV3 | ProtoTraceInfoV4: from mlflow.entities.trace_info_v2 import _truncate_request_metadata, _truncate_tags if self._is_v4(): from mlflow.utils.databricks_tracing_utils import trace_info_to_v4_proto return trace_info_to_v4_proto(self) request_time = Timestamp() request_time.FromMilliseconds(self.request_time) execution_duration = None if self.execution_duration is not None: execution_duration = Duration() execution_duration.FromMilliseconds(self.execution_duration) return ProtoTraceInfoV3( trace_id=self.trace_id, client_request_id=self.client_request_id, trace_location=self.trace_location.to_proto(), request_preview=self.request_preview, response_preview=self.response_preview, request_time=request_time, execution_duration=execution_duration, state=self.state.to_proto(), trace_metadata=_truncate_request_metadata(self.trace_metadata), tags=_truncate_tags(self.tags), assessments=[a.to_proto() for a in self.assessments], )
    @classmethod
    def from_proto(cls, proto) -> "TraceInfo":
        # Legacy v2 protos have a "request_id" field; route them through TraceInfoV2
        # and upconvert the result to the v3 entity.
        if "request_id" in proto.DESCRIPTOR.fields_by_name:
            from mlflow.entities.trace_info_v2 import TraceInfoV2

            return TraceInfoV2.from_proto(proto).to_v3()

        # import inside the function to avoid introducing top-level dependency on
        # mlflow.tracing.utils in entities module
        from mlflow.tracing.utils import construct_trace_id_v4

        trace_location = TraceLocation.from_proto(proto.trace_location)
        if trace_location.uc_schema:
            # v4 (UC schema) traces embed the "catalog.schema" location in the trace ID.
            trace_id = construct_trace_id_v4(
                location=f"{trace_location.uc_schema.catalog_name}.{trace_location.uc_schema.schema_name}",
                trace_id=proto.trace_id,
            )
        else:
            trace_id = proto.trace_id
        return cls(
            trace_id=trace_id,
            # Optional proto fields are mapped to None when unset.
            client_request_id=(
                proto.client_request_id if proto.HasField("client_request_id") else None
            ),
            trace_location=trace_location,
            request_preview=proto.request_preview if proto.HasField("request_preview") else None,
            response_preview=proto.response_preview if proto.HasField("response_preview") else None,
            request_time=proto.request_time.ToMilliseconds(),
            execution_duration=(
                proto.execution_duration.ToMilliseconds()
                if proto.HasField("execution_duration")
                else None
            ),
            state=TraceState.from_proto(proto.state),
            trace_metadata=dict(proto.trace_metadata),
            tags=dict(proto.tags),
            assessments=[Assessment.from_proto(a) for a in proto.assessments],
        )
# Aliases for backward compatibility with V2 format @property def request_id(self) -> str: """Deprecated. Use `trace_id` instead.""" return self.trace_id @property def experiment_id(self) -> str | None: """ An MLflow experiment ID associated with the trace, if the trace is stored in MLflow tracking server. Otherwise, None. """ return ( self.trace_location.mlflow_experiment and self.trace_location.mlflow_experiment.experiment_id ) @experiment_id.setter def experiment_id(self, value: str | None) -> None: self.trace_location.mlflow_experiment.experiment_id = value @property def request_metadata(self) -> dict[str, str]: """Deprecated. Use `trace_metadata` instead.""" return self.trace_metadata @property def timestamp_ms(self) -> int: return self.request_time @timestamp_ms.setter def timestamp_ms(self, value: int) -> None: self.request_time = value @property def execution_time_ms(self) -> int | None: return self.execution_duration @execution_time_ms.setter def execution_time_ms(self, value: int | None) -> None: self.execution_duration = value @property def status(self) -> TraceStatus: """Deprecated. Use `state` instead.""" return TraceStatus.from_state(self.state) @status.setter def status(self, value: TraceStatus) -> None: self.state = value.to_state() @property def token_usage(self) -> dict[str, int] | None: """ Returns the aggregated token usage for the trace. Returns: A dictionary containing the aggregated LLM token usage for the trace. - "input_tokens": The total number of input tokens. - "output_tokens": The total number of output tokens. - "total_tokens": Sum of input and output tokens. .. note:: The token usage tracking is not supported for all LLM providers. Refer to the MLflow Tracing documentation for which providers support token usage tracking. 
""" if usage_json := self.trace_metadata.get(TraceMetadataKey.TOKEN_USAGE): return json.loads(usage_json) return None @property def cost(self) -> dict[str, float] | None: """ Returns the aggregated cost for the trace in USD. Returns: A dictionary containing the aggregated LLM cost for the trace. - "input_cost": The total cost for input tokens. - "output_cost": The total cost for output tokens. - "total_cost": Sum of input and output costs. .. note:: The cost tracking is calculated based on token usage and model pricing from LiteLLM. Cost tracking is not supported for all LLM providers. Refer to the MLflow Tracing documentation for which providers support cost tracking. """ if cost_json := self.trace_metadata.get(TraceMetadataKey.COST): return json.loads(cost_json) return None def _is_v4(self) -> bool: return self.trace_location.uc_schema is not None