from __future__ import annotations
import json
import logging
import re
from dataclasses import dataclass
from typing import Any, Optional, Union
from mlflow.entities._mlflow_object import _MlflowObject
from mlflow.entities.span import Span, SpanType
from mlflow.entities.trace_data import TraceData
from mlflow.entities.trace_info import TraceInfo
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.protos.databricks_trace_server_pb2 import Trace as ProtoTrace
from mlflow.protos.databricks_trace_server_pb2 import TraceData as ProtoTraceData
# Module-level logger, named after this module's import path.
_logger = logging.getLogger(__name__)
@dataclass
class Trace(_MlflowObject):
    """A trace object.

    Args:
        info: A lightweight object that contains the metadata of a trace.
        data: A container object that holds the spans data of a trace.
    """

    info: TraceInfo
    data: TraceData

    def __repr__(self) -> str:
        # Keep the repr short: only the request ID, not the (potentially large) span data.
        return f"Trace(request_id={self.info.request_id})"

    def to_dict(self) -> dict[str, Any]:
        """Return a dictionary with the keys ``"info"`` and ``"data"``.

        Each value is produced by the ``to_dict()`` method of the corresponding attribute.
        """
        return {"info": self.info.to_dict(), "data": self.data.to_dict()}

    def to_json(self, pretty: bool = False) -> str:
        """Serialize the trace to a JSON string.

        Args:
            pretty: If True, indent the output with two spaces; otherwise emit
                the compact single-line form.

        Returns:
            The JSON encoding of :py:meth:`to_dict`, produced with ``TraceJSONEncoder``.
        """
        # Imported lazily, as in the original implementation.
        from mlflow.tracing.utils import TraceJSONEncoder

        return json.dumps(self.to_dict(), cls=TraceJSONEncoder, indent=2 if pretty else None)

    @classmethod
    def from_dict(cls, trace_dict: dict[str, Any]) -> Trace:
        """Build a trace from a dictionary with the keys ``"info"`` and ``"data"``.

        Args:
            trace_dict: A mapping holding the serialized trace info and trace data.

        Returns:
            A new ``Trace`` instance.

        Raises:
            MlflowException: If ``trace_dict`` is not a mapping, or if either the
                ``"info"`` or the ``"data"`` entry is missing or ``None``.
        """
        from collections.abc import Mapping

        # Valid JSON documents such as "[]" or "null" decode to non-mapping values.
        # Reject them with a proper parameter error instead of an AttributeError.
        if not isinstance(trace_dict, Mapping):
            raise MlflowException(
                "Unable to parse Trace from dictionary. Expected a dictionary with keys "
                f"'info' and 'data'. Received: {type(trace_dict)}",
                error_code=INVALID_PARAMETER_VALUE,
            )

        info = trace_dict.get("info")
        data = trace_dict.get("data")
        if info is None or data is None:
            raise MlflowException(
                "Unable to parse Trace from dictionary. Expected keys: 'info' and 'data'. "
                f"Received keys: {list(trace_dict.keys())}",
                error_code=INVALID_PARAMETER_VALUE,
            )

        return cls(
            info=TraceInfo.from_dict(info),
            data=TraceData.from_dict(data),
        )

    @classmethod
    def from_json(cls, trace_json: str) -> Trace:
        """Build a trace from a JSON string.

        Args:
            trace_json: A JSON document containing the ``"info"`` and ``"data"`` keys.

        Returns:
            A new ``Trace`` instance.

        Raises:
            MlflowException: If ``trace_json`` is not valid JSON, or if the decoded
                value is rejected by :py:meth:`from_dict`.
        """
        try:
            trace_dict = json.loads(trace_json)
        except json.JSONDecodeError as e:
            # Chain the decoder error so the original position info stays in the traceback.
            raise MlflowException(
                f"Unable to parse trace JSON: {trace_json}. Error: {e}",
                error_code=INVALID_PARAMETER_VALUE,
            ) from e
        return cls.from_dict(trace_dict)

    def _serialize_for_mimebundle(self):
        """Return the JSON-encoded request ID used as the Databricks MIME payload."""
        # databricks notebooks will use the request ID to
        # fetch the trace from the backend. including the
        # full JSON can cause notebooks to exceed size limits
        return json.dumps(self.info.request_id)

    def _repr_mimebundle_(self, include=None, exclude=None):
        """
        This method is used to trigger custom display logic in IPython notebooks.
        See https://ipython.readthedocs.io/en/stable/config/integrating.html#MyObject
        for more details.

        The returned bundle always contains a ``"text/plain"`` entry holding ``repr(self)``.
        Unless the display handler is disabled, one more entry is added:

        - In a Databricks runtime: ``"application/databricks.mlflow.trace"``, whose value
          is the JSON-encoded request ID (not the full trace).
        - Otherwise, when a tracking server is in use: ``"text/html"``, the markup
          returned by ``get_notebook_iframe_html`` for this trace.

        ``include`` and ``exclude`` are accepted for protocol compatibility and are unused.
        """
        from mlflow.tracing.display import (
            get_display_handler,
            get_notebook_iframe_html,
            is_using_tracking_server,
        )
        from mlflow.utils.databricks_utils import is_in_databricks_runtime

        bundle = {"text/plain": repr(self)}

        if not get_display_handler().disabled:
            if is_in_databricks_runtime():
                bundle["application/databricks.mlflow.trace"] = self._serialize_for_mimebundle()
            elif is_using_tracking_server():
                bundle["text/html"] = get_notebook_iframe_html([self])

        return bundle

    def to_pandas_dataframe_row(self) -> dict[str, Any]:
        """Return this trace as a single row keyed by column name.

        The keys are the same names, in the same order, as those listed by
        :py:meth:`pandas_dataframe_columns`; keep the two in sync.
        """
        return {
            "request_id": self.info.request_id,
            "trace": self,
            "timestamp_ms": self.info.timestamp_ms,
            "status": self.info.status,
            "execution_time_ms": self.info.execution_time_ms,
            # request / response are parsed as JSON when possible, else passed through.
            "request": self._deserialize_json_attr(self.data.request),
            "response": self._deserialize_json_attr(self.data.response),
            "request_metadata": self.info.request_metadata,
            "spans": [span.to_dict() for span in self.data.spans],
            "tags": self.info.tags,
            "assessments": self.info.assessments,
        }

    def _deserialize_json_attr(self, value: Optional[str]) -> Any:
        """Best-effort JSON decoding of ``value``.

        Returns the decoded object when ``value`` is valid JSON; otherwise returns
        ``value`` unchanged. This method never raises.
        """
        if value is None:
            # json.loads(None) always fails; skip the attempt and the debug traceback.
            return None
        try:
            return json.loads(value)
        except Exception:
            # Deliberately broad: this is a display helper and must not fail the caller.
            _logger.debug("Failed to deserialize JSON attribute: %s", value, exc_info=True)
            return value

    def search_spans(
        self, span_type: Optional[SpanType] = None, name: Optional[Union[str, re.Pattern]] = None
    ) -> list[Span]:
        """
        Search for spans that match the given criteria within the trace.

        Args:
            span_type: The type of the span to search for.
            name: The name of the span to search for. This can be a string or a regular
                expression. A string must equal the span name exactly; a compiled pattern
                is applied with ``search``, so it may match anywhere in the span name.

        Returns:
            A list of spans that match the given criteria.
            If there is no match, an empty list is returned.

        Raises:
            MlflowException: If ``name`` is neither a string, a compiled pattern, nor
                ``None``, or if ``span_type`` is neither a string nor ``None``. The
                arguments are validated before any span is inspected.

        .. code-block:: python

            import mlflow
            import re
            from mlflow.entities import SpanType


            @mlflow.trace(span_type=SpanType.CHAIN)
            def run(x: int) -> int:
                x = add_one(x)
                x = add_two(x)
                x = multiply_by_two(x)
                return x


            @mlflow.trace(span_type=SpanType.TOOL)
            def add_one(x: int) -> int:
                return x + 1


            @mlflow.trace(span_type=SpanType.TOOL)
            def add_two(x: int) -> int:
                return x + 2


            @mlflow.trace(span_type=SpanType.TOOL)
            def multiply_by_two(x: int) -> int:
                return x * 2


            # Run the function and get the trace
            y = run(2)
            trace = mlflow.get_last_active_trace()

            # 1. Search spans by name (exact match)
            spans = trace.search_spans(name="add_one")
            print(spans)
            # Output: [Span(name='add_one', ...)]

            # 2. Search spans by name (regular expression)
            pattern = re.compile(r"add.*")
            spans = trace.search_spans(name=pattern)
            print(spans)
            # Output: [Span(name='add_one', ...), Span(name='add_two', ...)]

            # 3. Search spans by type
            spans = trace.search_spans(span_type=SpanType.CHAIN)
            print(spans)
            # Output: [Span(name='run', ...)]

            # 4. Search spans by name and type
            spans = trace.search_spans(name="add_one", span_type=SpanType.TOOL)
            print(spans)
            # Output: [Span(name='add_one', ...)]
        """
        # Validate once, up front, so that invalid arguments are rejected even when
        # the trace has no spans or the other filter already excludes every span.
        if name is not None and not isinstance(name, (str, re.Pattern)):
            raise MlflowException(
                f"Invalid type for 'name'. Expected str or re.Pattern. Got: {type(name)}",
                error_code=INVALID_PARAMETER_VALUE,
            )
        if span_type is not None and not isinstance(span_type, str):
            raise MlflowException(
                "Invalid type for 'span_type'. Expected str or mlflow.entities.SpanType. "
                f"Got: {type(span_type)}",
                error_code=INVALID_PARAMETER_VALUE,
            )

        def _match_name(span: Span) -> bool:
            if name is None:
                return True
            if isinstance(name, str):
                return span.name == name
            # Remaining case after validation: a compiled regular expression.
            return name.search(span.name) is not None

        def _match_type(span: Span) -> bool:
            return span_type is None or span.span_type == span_type

        return [span for span in self.data.spans if _match_name(span) and _match_type(span)]

    @staticmethod
    def pandas_dataframe_columns() -> list[str]:
        """Return the column names of the row built by :py:meth:`to_pandas_dataframe_row`."""
        return [
            "request_id",
            "trace",
            "timestamp_ms",
            "status",
            "execution_time_ms",
            "request",
            "response",
            "request_metadata",
            "spans",
            "tags",
            "assessments",
        ]

    def to_proto(self) -> ProtoTrace:
        """Convert into a proto object to be sent to the Databricks Trace Server."""
        return ProtoTrace(
            # Convert MLflow's TraceInfoV3 to Databricks Trace Server's TraceInfo
            info=self.info.to_v3_proto(self.data.request, self.data.response),
            data=ProtoTraceData(spans=[span.to_proto() for span in self.data.spans]),
        )