Source code for mlflow.genai.judges.utils

from __future__ import annotations

import json
import logging
import re
import threading
import time
import traceback
import warnings
from contextlib import ContextDecorator
from dataclasses import asdict, dataclass, is_dataclass
from typing import TYPE_CHECKING, Any, NamedTuple

import pydantic
import requests

if TYPE_CHECKING:
    import litellm

    from mlflow.genai.judges.base import AlignmentOptimizer, JudgeField
    from mlflow.types.llm import ChatMessage, ToolCall

import mlflow
from mlflow.entities.assessment import Feedback
from mlflow.entities.assessment_source import AssessmentSource, AssessmentSourceType
from mlflow.entities.trace import Trace
from mlflow.environment_variables import MLFLOW_JUDGE_MAX_ITERATIONS
from mlflow.exceptions import MlflowException
from mlflow.genai.judges.constants import _DATABRICKS_DEFAULT_JUDGE_MODEL
from mlflow.genai.utils.enum_utils import StrEnum
from mlflow.protos.databricks_pb2 import (
    BAD_REQUEST,
    INVALID_PARAMETER_VALUE,
    REQUEST_LIMIT_EXCEEDED,
)
from mlflow.telemetry.events import InvokeCustomJudgeModelEvent
from mlflow.telemetry.track import record_usage_event
from mlflow.telemetry.utils import _is_in_databricks
from mlflow.utils.uri import is_databricks_uri
from mlflow.version import VERSION

_logger = logging.getLogger(__name__)

# "endpoints" is a special case for Databricks model serving endpoints.
_NATIVE_PROVIDERS = ["openai", "anthropic", "bedrock", "mistral", "endpoints"]

# Global cache to track model capabilities across function calls
# Key: model URI (e.g., "openai/gpt-4"), Value: boolean indicating response_format support
_MODEL_RESPONSE_FORMAT_CAPABILITIES: dict[str, bool] = {}


class DatabricksLLMJudgePrompts(NamedTuple):
    """Result of splitting ChatMessage list for Databricks API."""

    system_prompt: str | None
    user_prompt: str


def _check_databricks_agents_installed() -> None:
    """Check if databricks-agents is installed for databricks judge functionality.

    Raises:
        MlflowException: If databricks-agents is not installed.
    """
    try:
        import databricks.agents.evals  # noqa: F401
    except ImportError:
        raise MlflowException(
            f"To use '{_DATABRICKS_DEFAULT_JUDGE_MODEL}' as the judge model, the Databricks "
            "agents library must be installed. Please install it with: "
            "`pip install databricks-agents`",
            error_code=BAD_REQUEST,
        )


def get_default_model() -> str:
    if is_databricks_uri(mlflow.get_tracking_uri()):
        return _DATABRICKS_DEFAULT_JUDGE_MODEL
    else:
        return "openai:/gpt-4.1-mini"


def get_default_optimizer() -> AlignmentOptimizer:
    """
    Get the default alignment optimizer.

    Returns:
        A SIMBA alignment optimizer with no model specified (uses default model).
    """
    from mlflow.genai.judges.optimizers.simba import SIMBAAlignmentOptimizer

    return SIMBAAlignmentOptimizer()


def _is_litellm_available() -> bool:
    """Check if LiteLLM is available for import."""
    try:
        import litellm  # noqa: F401

        return True
    except ImportError:
        return False


def validate_judge_model(model_uri: str) -> None:
    """
    Validate that a judge model URI is valid and has required dependencies.

    This function performs early validation at judge construction time to provide
    fast feedback about configuration issues.

    Args:
        model_uri: The model URI to validate (e.g., "databricks", "openai:/gpt-4")

    Raises:
        MlflowException: If the model URI is invalid or required dependencies are missing.
    """
    from mlflow.metrics.genai.model_utils import _parse_model_uri

    # Special handling for Databricks default model
    if model_uri == _DATABRICKS_DEFAULT_JUDGE_MODEL:
        # Check if databricks-agents is available
        _check_databricks_agents_installed()
        return

    # Validate the URI format and extract provider
    provider, model_name = _parse_model_uri(model_uri)

    # Check if LiteLLM is required and available for non-native providers
    if provider not in _NATIVE_PROVIDERS:
        if not _is_litellm_available():
            raise MlflowException(
                f"LiteLLM is required for using '{provider}' as a provider. "
                "Please install it with: `pip install litellm`",
                error_code=INVALID_PARAMETER_VALUE,
            )
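

# Illustrative usage sketch (not part of the original module): validates judge model
# URIs before judge construction. The URIs below are hypothetical examples.
def _example_validate_judge_model() -> None:
    # Passes for a natively supported provider (no LiteLLM required).
    validate_judge_model("openai:/gpt-4.1-mini")
    # Raises MlflowException for a non-native provider if LiteLLM is not installed.
    validate_judge_model("gemini:/gemini-1.5-pro")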


def format_prompt(prompt: str, **values) -> str:
    """Format double-curly variables in the prompt template."""
    for key, value in values.items():
        # Escape backslashes in the replacement string to prevent re.sub from interpreting
        # them as escape sequences (e.g. \u being treated as Unicode escape)
        replacement = str(value).replace("\\", "\\\\")
        prompt = re.sub(r"\{\{\s*" + key + r"\s*\}\}", replacement, prompt)
    return prompt
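

# Illustrative usage sketch (not part of the original module): shows how format_prompt
# substitutes {{ variable }} placeholders and preserves backslashes in values. The
# helper name below is hypothetical.
def _example_format_prompt_usage() -> str:
    template = "Question: {{ question }}\nExpected: {{ expected }}"
    # Backslashes in values (e.g. file paths) survive the re.sub call verbatim.
    return format_prompt(template, question="What is 2 + 2?", expected="C:\\answers\\four")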


def add_output_format_instructions(prompt: str, output_fields: list["JudgeField"]) -> str:
    """
    Add structured output format instructions to a judge prompt.

    This ensures the LLM returns a JSON response with the expected fields,
    matching the expected format for the invoke_judge_model function.

    Args:
        prompt: The formatted prompt with template variables filled in
        output_fields: List of JudgeField objects defining output fields.

    Returns:
        The prompt with output format instructions appended
    """
    json_format_lines = []
    for field in output_fields:
        json_format_lines.append(f'    "{field.name}": "{field.description}"')

    json_format = "{\n" + ",\n".join(json_format_lines) + "\n}"

    output_format_instructions = f"""

Please provide your assessment in the following JSON format only (no markdown):

{json_format}"""
    return prompt + output_format_instructions
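

# Illustrative sketch (not part of the original module): appends JSON output
# instructions for the standard result/rationale fields. JudgeField is assumed to
# accept name/description keyword arguments; the helper name is hypothetical.
def _example_add_output_format_instructions() -> str:
    from mlflow.genai.judges.base import JudgeField

    fields = [
        JudgeField(name="result", description="The categorical rating, e.g. 'yes' or 'no'"),
        JudgeField(name="rationale", description="A short justification for the rating"),
    ]
    return add_output_format_instructions("Is the response grounded in the context?", fields)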


def _strip_markdown_code_blocks(response: str) -> str:
    """
    Strip markdown code blocks from LLM responses.

    Some legacy models wrap JSON responses in markdown code blocks (```json...```).
    This function removes those wrappers to extract the raw JSON content.

    Args:
        response: The raw response from the LLM

    Returns:
        The response with markdown code blocks removed
    """
    cleaned = response.strip()
    if not cleaned.startswith("```"):
        return cleaned

    lines = cleaned.split("\n")
    start_idx = 0
    end_idx = len(lines)

    for i, line in enumerate(lines):
        if i == 0 and line.startswith("```"):
            start_idx = 1
        elif line.strip() == "```" and i > 0:
            end_idx = i
            break

    return "\n".join(lines[start_idx:end_idx])


def _sanitize_justification(justification: str) -> str:
    # Some judge prompts instruct the model to think step by step.
    return justification.replace("Let's think step by step. ", "")


def _split_messages_for_databricks(messages: list["ChatMessage"]) -> DatabricksLLMJudgePrompts:
    """
    Split a list of ChatMessage objects into system and user prompts for Databricks API.

    Args:
        messages: List of ChatMessage objects to split.

    Returns:
        DatabricksLLMJudgePrompts namedtuple with system_prompt and user_prompt fields.
        The system_prompt may be None.

    Raises:
        MlflowException: If the messages list is empty or invalid.
    """
    from mlflow.types.llm import ChatMessage

    if not messages:
        raise MlflowException(
            "Invalid prompt format: expected non-empty list of ChatMessage",
            error_code=BAD_REQUEST,
        )

    system_prompt = None
    user_parts = []

    for msg in messages:
        if isinstance(msg, ChatMessage):
            if msg.role == "system":
                # Use the first system message as the actual system prompt for the API.
                # Any subsequent system messages are appended to the user prompt to preserve
                # their content and maintain the order in which they appear in the submitted
                # evaluation payload.
                if system_prompt is None:
                    system_prompt = msg.content
                else:
                    user_parts.append(f"System: {msg.content}")
            elif msg.role == "user":
                user_parts.append(msg.content)
            elif msg.role == "assistant":
                user_parts.append(f"Assistant: {msg.content}")

    user_prompt = "\n\n".join(user_parts) if user_parts else ""

    return DatabricksLLMJudgePrompts(system_prompt=system_prompt, user_prompt=user_prompt)
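

# Illustrative sketch (not part of the original module): shows how a typical judge
# conversation is split into the system/user prompts expected by the Databricks chat
# completions API. The helper name is hypothetical.
def _example_split_messages_for_databricks() -> DatabricksLLMJudgePrompts:
    from mlflow.types.llm import ChatMessage

    messages = [
        ChatMessage(role="system", content="You are an impartial judge."),
        ChatMessage(role="user", content="Rate the following answer."),
        ChatMessage(role="assistant", content="Understood."),
    ]
    # system_prompt == "You are an impartial judge."
    # user_prompt == "Rate the following answer.\n\nAssistant: Understood."
    return _split_messages_for_databricks(messages)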


def _parse_databricks_judge_response(
    llm_output: str | None,
    assessment_name: str,
) -> Feedback:
    """
    Parse the response from Databricks judge into a Feedback object.

    Args:
        llm_output: Raw output from the LLM, or None if no response.
        assessment_name: Name of the assessment.

    Returns:
        Feedback object with parsed results or error.
    """
    source = AssessmentSource(
        source_type=AssessmentSourceType.LLM_JUDGE, source_id=_DATABRICKS_DEFAULT_JUDGE_MODEL
    )
    if not llm_output:
        return Feedback(
            name=assessment_name,
            error="Empty response from Databricks judge",
            source=source,
        )
    try:
        response_data = json.loads(llm_output)
    except json.JSONDecodeError as e:
        return Feedback(
            name=assessment_name,
            error=f"Invalid JSON response from Databricks judge: {e}",
            source=source,
        )
    if "result" not in response_data:
        return Feedback(
            name=assessment_name,
            error=f"Response missing 'result' field: {response_data}",
            source=source,
        )
    return Feedback(
        name=assessment_name,
        value=response_data["result"],
        rationale=response_data.get("rationale", ""),
        source=source,
    )


def call_chat_completions(user_prompt: str, system_prompt: str | None) -> Any:
    """
    Invokes the Databricks chat completions API using the databricks.agents.evals library.

    Args:
        user_prompt (str): The user prompt.
        system_prompt (str | None): The system prompt, or None if no system message is provided.

    Returns:
        The chat completions result.

    Raises:
        MlflowException: If databricks-agents is not installed.
    """
    _check_databricks_agents_installed()

    from databricks.rag_eval import context, env_vars

    env_vars.RAG_EVAL_EVAL_SESSION_CLIENT_NAME.set(f"mlflow-judge-optimizer-v{VERSION}")

    @context.eval_context
    def _call_chat_completions(user_prompt: str, system_prompt: str):
        managed_rag_client = context.get_context().build_managed_rag_client()

        return managed_rag_client.get_chat_completions_result(
            user_prompt=user_prompt,
            system_prompt=system_prompt,
        )

    return _call_chat_completions(user_prompt, system_prompt)


def _invoke_databricks_judge(
    prompt: str | list["ChatMessage"],
    assessment_name: str,
) -> Feedback:
    """
    Invoke the Databricks judge using the databricks.agents.evals library.

    Uses the direct chat completions API for clean prompt submission without
    any additional formatting or template requirements.

    Args:
        prompt: The formatted prompt with template variables filled in.
        assessment_name: The name of the assessment.

    Returns:
        Feedback object from the Databricks judge.

    Raises:
        MlflowException: If databricks-agents is not installed.
    """
    try:
        if isinstance(prompt, str):
            system_prompt = None
            user_prompt = prompt
        else:
            prompts = _split_messages_for_databricks(prompt)
            system_prompt = prompts.system_prompt
            user_prompt = prompts.user_prompt

        llm_result = call_chat_completions(user_prompt, system_prompt)
        return _parse_databricks_judge_response(llm_result.output, assessment_name)

    except Exception as e:
        _logger.debug(f"Failed to invoke Databricks judge: {e}")
        return Feedback(
            name=assessment_name,
            error=f"Failed to invoke Databricks judge: {e}",
            source=AssessmentSource(
                source_type=AssessmentSourceType.LLM_JUDGE,
                source_id=_DATABRICKS_DEFAULT_JUDGE_MODEL,
            ),
        )


def _invoke_via_gateway(
    model_uri: str,
    provider: str,
    prompt: str,
) -> str:
    """
    Invoke the judge model via native AI Gateway adapters.

    Args:
        model_uri: The full model URI.
        provider: The provider name.
        prompt: The prompt to evaluate.

    Returns:
        The JSON response string from the model.

    Raises:
        MlflowException: If the provider is not natively supported or invocation fails.
    """
    from mlflow.metrics.genai.model_utils import get_endpoint_type, score_model_on_payload

    if provider not in _NATIVE_PROVIDERS:
        raise MlflowException(
            f"LiteLLM is required for using '{provider}' LLM. Please install it with "
            "`pip install litellm`.",
            error_code=BAD_REQUEST,
        )

    return score_model_on_payload(
        model_uri=model_uri,
        payload=prompt,
        endpoint_type=get_endpoint_type(model_uri) or "llm/v1/chat",
    )


class FieldExtraction(pydantic.BaseModel):
    """Schema for extracting inputs and outputs from traces using LLM."""

    inputs: str = pydantic.Field(description="The user's original request or question")
    outputs: str = pydantic.Field(description="The system's final response")


@record_usage_event(InvokeCustomJudgeModelEvent)
def invoke_judge_model(
    model_uri: str,
    prompt: str | list["ChatMessage"],
    assessment_name: str,
    trace: Trace | None = None,
    num_retries: int = 10,
) -> Feedback:
    """
    Invoke the judge model.

    Routes to the appropriate implementation based on the model URI:
    - "databricks": Uses databricks.agents.evals library for default judge,
                    direct API for regular endpoints
    - LiteLLM-supported providers: Uses LiteLLM if available
    - Native providers: Falls back to AI Gateway adapters

    Args:
        model_uri: The model URI.
        prompt: The prompt to evaluate. Can be a string (single prompt) or
                a list of ChatMessage objects.
        assessment_name: The name of the assessment.
        trace: Optional trace object for context.
        num_retries: Number of retries on transient failures when using litellm.

    Returns:
        Feedback object with the judge's assessment.

    Raises:
        MlflowException: If the model cannot be invoked or dependencies are missing.
    """
    if model_uri == _DATABRICKS_DEFAULT_JUDGE_MODEL:
        return _invoke_databricks_judge(prompt, assessment_name)

    from mlflow.metrics.genai.model_utils import _parse_model_uri
    from mlflow.types.llm import ChatMessage

    model_provider, model_name = _parse_model_uri(model_uri)
    in_databricks = _is_in_databricks()

    # Handle Databricks endpoints (not the default judge) with proper telemetry
    if model_provider in {"databricks", "endpoints"} and isinstance(prompt, str):
        if model_provider == "endpoints":
            warnings.warn(
                "The legacy provider 'endpoints' is deprecated and will be removed in a future"
                "release. Please update your code to use the 'databricks' provider instead.",
                FutureWarning,
                stacklevel=2,
            )
        try:
            output = _invoke_databricks_judge_model(
                model_name=model_name,
                prompt=prompt,
                assessment_name=assessment_name,
                num_retries=num_retries,
            )
            feedback = output.feedback
            feedback.trace_id = trace.info.trace_id if trace is not None else None

            # Record success telemetry only when in Databricks
            if in_databricks:
                try:
                    _record_judge_model_usage_success_databricks_telemetry(
                        request_id=output.request_id,
                        model_provider=output.model_provider,
                        endpoint_name=output.model_name,
                        num_prompt_tokens=output.num_prompt_tokens,
                        num_completion_tokens=output.num_completion_tokens,
                    )
                except Exception as telemetry_error:
                    _logger.debug(
                        "Failed to record judge model usage success telemetry. Error: %s",
                        telemetry_error,
                    )

            return feedback

        except Exception:
            # Record failure telemetry only when in Databricks
            if in_databricks:
                try:
                    provider = "databricks" if model_provider == "endpoints" else model_provider
                    _record_judge_model_usage_failure_databricks_telemetry(
                        model_provider=provider,
                        endpoint_name=model_name,
                        error_code="UNKNOWN",
                        error_message=traceback.format_exc(),
                    )
                except Exception as telemetry_error:
                    _logger.debug(
                        "Failed to record judge model usage failure telemetry. Error: %s",
                        telemetry_error,
                    )
            raise

    # Handle all other cases (including non-Databricks, ChatMessage prompts, traces)
    messages = [ChatMessage(role="user", content=prompt)] if isinstance(prompt, str) else prompt

    if _is_litellm_available():
        response = _invoke_litellm_and_handle_tools(
            provider=model_provider,
            model_name=model_name,
            messages=messages,
            trace=trace,
            num_retries=num_retries,
        )
    elif trace is not None:
        raise MlflowException(
            "LiteLLM is required for using traces with judges. "
            "Please install it with `pip install litellm`.",
            error_code=BAD_REQUEST,
        )
    else:
        if not isinstance(prompt, str):
            raise MlflowException(
                "This judge is not supported by native LLM providers. Please install "
                "LiteLLM with `pip install litellm` to use this judge.",
                error_code=BAD_REQUEST,
            )
        response = _invoke_via_gateway(model_uri, model_provider, prompt)

    cleaned_response = _strip_markdown_code_blocks(response)

    try:
        response_dict = json.loads(cleaned_response)
    except json.JSONDecodeError as e:
        raise MlflowException(
            f"Failed to parse response from judge model. Response: {response}",
            error_code=BAD_REQUEST,
        ) from e

    feedback = Feedback(
        name=assessment_name,
        value=response_dict["result"],
        rationale=_sanitize_justification(response_dict.get("rationale", "")),
        source=AssessmentSource(source_type=AssessmentSourceType.LLM_JUDGE, source_id=model_uri),
        trace_id=trace.info.trace_id if trace is not None else None,
    )

    if "error" in response_dict:
        feedback.error = response_dict["error"]
        raise MlflowException(
            f"Judge evaluation failed with error: {response_dict['error']}",
            error_code=INVALID_PARAMETER_VALUE,
        )

    return feedback
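

# Illustrative usage sketch (not part of the original module): invokes a judge with a
# plain-string prompt that already contains JSON output instructions. The prompt text
# and assessment name are hypothetical.
def _example_invoke_judge_model() -> Feedback:
    prompt = (
        "Is the answer '4' correct for the question 'What is 2 + 2?'\n\n"
        'Respond in JSON only: {"result": "yes" or "no", "rationale": "..."}'
    )
    return invoke_judge_model(
        model_uri="openai:/gpt-4.1-mini",
        prompt=prompt,
        assessment_name="correctness",
    )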


def get_chat_completions_with_structured_output(
    model_uri: str,
    messages: list["ChatMessage"],
    output_schema: type[pydantic.BaseModel],
    trace: Trace | None = None,
    num_retries: int = 10,
) -> pydantic.BaseModel:
    """
    Get chat completions from an LLM with structured output conforming to a Pydantic schema.

    This function invokes an LLM and ensures the response matches the provided Pydantic schema.
    When a trace is provided, the LLM can use tool calling to examine trace spans.

    Args:
        model_uri: The model URI (e.g., "openai:/gpt-4", "anthropic:/claude-3").
        messages: List of ChatMessage objects for the conversation with the LLM.
        output_schema: Pydantic model class defining the expected output structure.
                       The LLM will be instructed to return data matching this schema.
        trace: Optional trace object for context. When provided, enables tool
               calling to examine trace spans.
        num_retries: Number of retries on transient failures. Defaults to 10 with
                     exponential backoff.

    Returns:
        Instance of output_schema with the structured data from the LLM.

    Raises:
        ImportError: If LiteLLM is not installed.
        JSONDecodeError: If the LLM response cannot be parsed as JSON.
        ValidationError: If the LLM response does not match the output schema.

    Example:
        .. code-block:: python

            from pydantic import BaseModel, Field
            from mlflow.genai.judges.utils import get_chat_completions_with_structured_output
            from mlflow.types.llm import ChatMessage


            class FieldExtraction(BaseModel):
                inputs: str = Field(description="The user's original request")
                outputs: str = Field(description="The system's final response")


            # Extract fields from a trace where root span lacks input/output
            # but nested spans contain the actual data
            result = get_chat_completions_with_structured_output(
                model_uri="openai:/gpt-4",
                messages=[
                    ChatMessage(role="system", content="Extract fields from the trace"),
                    ChatMessage(role="user", content="Find the inputs and outputs"),
                ],
                output_schema=FieldExtraction,
                trace=trace,  # Trace with nested spans containing actual data
            )
            print(result.inputs)  # Extracted from inner span
            print(result.outputs)  # Extracted from inner span
    """
    from mlflow.metrics.genai.model_utils import _parse_model_uri

    model_provider, model_name = _parse_model_uri(model_uri)

    response = _invoke_litellm_and_handle_tools(
        provider=model_provider,
        model_name=model_name,
        messages=messages,
        trace=trace,
        num_retries=num_retries,
        response_format=output_schema,
    )

    cleaned_response = _strip_markdown_code_blocks(response)
    response_dict = json.loads(cleaned_response)
    return output_schema(**response_dict)


@dataclass
class InvokeDatabricksModelOutput:
    response: str
    request_id: str | None
    num_prompt_tokens: int | None
    num_completion_tokens: int | None


def _parse_databricks_model_response(
    res_json: dict[str, Any], headers: dict[str, Any]
) -> InvokeDatabricksModelOutput:
    """
    Parse and validate the response from a Databricks model invocation.

    Args:
        res_json: The JSON response from the model
        headers: The response headers

    Returns:
        InvokeDatabricksModelOutput with parsed response data

    Raises:
        MlflowException: If the response structure is invalid
    """
    # Validate and extract choices
    choices = res_json.get("choices", [])
    if not choices:
        raise MlflowException(
            "Invalid response from Databricks model: missing 'choices' field",
            error_code=INVALID_PARAMETER_VALUE,
        )

    first_choice = choices[0]
    if "message" not in first_choice:
        raise MlflowException(
            "Invalid response from Databricks model: missing 'message' field",
            error_code=INVALID_PARAMETER_VALUE,
        )

    content = first_choice.get("message", {}).get("content")
    if content is None:
        raise MlflowException(
            "Invalid response from Databricks model: missing 'content' field",
            error_code=INVALID_PARAMETER_VALUE,
        )

    # Handle reasoning response (list of content items)
    if isinstance(content, list):
        text_content = None
        for item in content:
            if isinstance(item, dict) and item.get("type") == "text":
                text_content = item.get("text")
                break

        if text_content is None:
            raise MlflowException(
                "Invalid reasoning response: no text content found in response list",
                error_code=INVALID_PARAMETER_VALUE,
            )
        content = text_content

    usage = res_json.get("usage", {})

    return InvokeDatabricksModelOutput(
        response=content,
        request_id=headers.get("x-request-id"),
        num_prompt_tokens=usage.get("prompt_tokens"),
        num_completion_tokens=usage.get("completion_tokens"),
    )
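

# Illustrative sketch (not part of the original module): parses a minimal
# chat-completions-style payload into InvokeDatabricksModelOutput. The payload and
# header values are hypothetical.
def _example_parse_databricks_model_response() -> InvokeDatabricksModelOutput:
    res_json = {
        "choices": [{"message": {"content": '{"result": "yes", "rationale": "ok"}'}}],
        "usage": {"prompt_tokens": 42, "completion_tokens": 7},
    }
    headers = {"x-request-id": "req-123"}
    return _parse_databricks_model_response(res_json, headers)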


def _invoke_databricks_model(
    *, model_name: str, prompt: str, num_retries: int
) -> InvokeDatabricksModelOutput:
    from mlflow.utils.databricks_utils import get_databricks_host_creds

    host_creds = get_databricks_host_creds()
    api_url = f"{host_creds.host}/serving-endpoints/{model_name}/invocations"

    # Implement retry logic with exponential backoff
    last_exception = None
    for attempt in range(num_retries + 1):
        try:
            res = requests.post(
                url=api_url,
                headers={"Authorization": f"Bearer {host_creds.token}"},
                json={
                    "messages": [
                        {
                            "role": "user",
                            "content": prompt,
                        }
                    ],
                },
            )
        except (requests.RequestException, requests.ConnectionError) as e:
            last_exception = e
            if attempt < num_retries:
                _logger.debug(f"Request attempt {attempt + 1} failed with error: {e}")
                time.sleep(2**attempt)  # Exponential backoff
                continue
            else:
                raise MlflowException(
                    f"Failed to invoke Databricks model after {num_retries + 1} attempts: {e}",
                    error_code=INVALID_PARAMETER_VALUE,
                ) from e

        # Check HTTP status before parsing JSON
        if res.status_code in [400, 401, 403, 404]:
            # Don't retry on bad request, unauthorized, not found, or forbidden
            raise MlflowException(
                f"Databricks model invocation failed with status {res.status_code}: {res.text}",
                error_code=INVALID_PARAMETER_VALUE,
            )

        if res.status_code >= 400:
            # For other errors, raise exception and potentially retry
            error_msg = (
                f"Databricks model invocation failed with status {res.status_code}: {res.text}"
            )
            if attempt < num_retries:
                # Log and retry for transient errors
                _logger.debug(f"Attempt {attempt + 1} failed: {error_msg}")
                time.sleep(2**attempt)  # Exponential backoff
                continue
            else:
                raise MlflowException(error_msg, error_code=INVALID_PARAMETER_VALUE)

        # Parse JSON response
        try:
            res_json = res.json()
        except json.JSONDecodeError as e:
            raise MlflowException(
                f"Failed to parse JSON response from Databricks model: {e}",
                error_code=INVALID_PARAMETER_VALUE,
            ) from e

        # Parse and validate the response using helper function
        return _parse_databricks_model_response(res_json, res.headers)

    # This should not be reached, but just in case
    if last_exception:
        raise MlflowException(
            f"Failed to invoke Databricks model: {last_exception}",
            error_code=INVALID_PARAMETER_VALUE,
        ) from last_exception


def _record_judge_model_usage_success_databricks_telemetry(
    *,
    request_id: str | None,
    model_provider: str,
    endpoint_name: str,
    num_prompt_tokens: int | None,
    num_completion_tokens: int | None,
) -> None:
    try:
        from databricks.agents.telemetry import record_judge_model_usage_success
    except ImportError:
        _logger.debug(
            "Failed to import databricks.agents.telemetry.record_judge_model_usage_success; "
            "databricks-agents needs to be installed."
        )
        return

    from mlflow.tracking.fluent import _get_experiment_id
    from mlflow.utils.databricks_utils import get_job_id, get_job_run_id, get_workspace_id

    record_judge_model_usage_success(
        request_id=request_id,
        experiment_id=_get_experiment_id(),
        job_id=get_job_id(),
        job_run_id=get_job_run_id(),
        workspace_id=get_workspace_id(),
        model_provider=model_provider,
        endpoint_name=endpoint_name,
        num_prompt_tokens=num_prompt_tokens,
        num_completion_tokens=num_completion_tokens,
    )


def _record_judge_model_usage_failure_databricks_telemetry(
    *,
    model_provider: str,
    endpoint_name: str,
    error_code: str,
    error_message: str,
) -> None:
    try:
        from databricks.agents.telemetry import record_judge_model_usage_failure
    except ImportError:
        _logger.debug(
            "Failed to import databricks.agents.telemetry.record_judge_model_usage_success; "
            "databricks-agents needs to be installed."
        )
        return

    from mlflow.tracking.fluent import _get_experiment_id
    from mlflow.utils.databricks_utils import get_job_id, get_job_run_id, get_workspace_id

    record_judge_model_usage_failure(
        experiment_id=_get_experiment_id(),
        job_id=get_job_id(),
        job_run_id=get_job_run_id(),
        workspace_id=get_workspace_id(),
        model_provider=model_provider,
        endpoint_name=endpoint_name,
        error_code=error_code,
        error_message=error_message,
    )


@dataclass
class InvokeJudgeModelHelperOutput:
    feedback: Feedback
    model_provider: str
    model_name: str
    request_id: str | None
    num_prompt_tokens: int | None
    num_completion_tokens: int | None


def _invoke_databricks_judge_model(
    *,
    model_name: str,
    prompt: str,
    assessment_name: str,
    num_retries: int = 10,
) -> InvokeJudgeModelHelperOutput:
    output = _invoke_databricks_model(
        model_name=model_name,
        prompt=prompt,
        num_retries=num_retries,
    )
    try:
        response_dict = json.loads(output.response)
        feedback = Feedback(
            name=assessment_name,
            value=response_dict["result"],
            rationale=_sanitize_justification(response_dict.get("rationale", "")),
            source=AssessmentSource(
                source_type=AssessmentSourceType.LLM_JUDGE,
                source_id=f"databricks:/{model_name}",
            ),
        )
    except json.JSONDecodeError as e:
        raise MlflowException(
            f"Failed to parse the response from the judge. Response: {output.response}",
            error_code=INVALID_PARAMETER_VALUE,
        ) from e

    return InvokeJudgeModelHelperOutput(
        feedback=feedback,
        model_provider="databricks",
        model_name=model_name,
        request_id=output.request_id,
        num_prompt_tokens=output.num_prompt_tokens,
        num_completion_tokens=output.num_completion_tokens,
    )


class _SuppressLiteLLMNonfatalErrors(ContextDecorator):
    """
    Thread-safe context manager and decorator to suppress LiteLLM's "Give Feedback" and
    "Provider List" messages. These messages indicate nonfatal bugs in the LiteLLM library;
    they are often noisy and can be safely ignored.

    Uses reference counting to ensure suppression remains active while any thread is running,
    preventing race conditions in parallel execution.
    """

    def __init__(self):
        self.lock = threading.RLock()
        self.count = 0
        self.original_litellm_settings = {}

    def __enter__(self) -> "_SuppressLiteLLMNonfatalErrors":
        try:
            import litellm
        except ImportError:
            return self

        with self.lock:
            if self.count == 0:
                # First caller - store original settings and enable suppression
                self.original_litellm_settings = {
                    "set_verbose": getattr(litellm, "set_verbose", None),
                    "suppress_debug_info": getattr(litellm, "suppress_debug_info", None),
                }
                litellm.set_verbose = False
                litellm.suppress_debug_info = True
            self.count += 1

        return self

    def __exit__(
        self,
        _exc_type: type[BaseException] | None,
        _exc_val: BaseException | None,
        _exc_tb: Any | None,
    ) -> bool:
        try:
            import litellm
        except ImportError:
            return False

        with self.lock:
            self.count -= 1
            if self.count == 0:
                # Last caller - restore original settings
                if (
                    original_verbose := self.original_litellm_settings.get("set_verbose")
                ) is not None:
                    litellm.set_verbose = original_verbose
                if (
                    original_suppress := self.original_litellm_settings.get("suppress_debug_info")
                ) is not None:
                    litellm.suppress_debug_info = original_suppress
                self.original_litellm_settings.clear()

        return False


# Global instance for use as threadsafe decorator
_suppress_litellm_nonfatal_errors = _SuppressLiteLLMNonfatalErrors()


def _invoke_litellm(
    litellm_model_uri: str,
    messages: list["litellm.Message"],
    tools: list[dict[str, Any]],
    num_retries: int,
    response_format: type[pydantic.BaseModel] | None,
    include_response_format: bool,
) -> "litellm.ModelResponse":
    """
    Invoke litellm completion with retry support.

    Args:
        litellm_model_uri: Full model URI for litellm (e.g., "openai/gpt-4").
        messages: List of litellm Message objects.
        tools: List of tool definitions (empty list if no tools).
        num_retries: Number of retries with exponential backoff.
        response_format: Optional Pydantic model class for structured output.
        include_response_format: Whether to include response_format in the request.

    Returns:
        The litellm ModelResponse object.

    Raises:
        Various litellm exceptions on failure.
    """
    import litellm

    kwargs = {
        "model": litellm_model_uri,
        "messages": messages,
        "tools": tools if tools else None,
        "tool_choice": "auto" if tools else None,
        "retry_policy": _get_litellm_retry_policy(num_retries),
        "retry_strategy": "exponential_backoff_retry",
        # In LiteLLM version 1.55.3+, max_retries is stacked on top of retry_policy.
        # To avoid double-retry, we set max_retries=0
        "max_retries": 0,
        # Drop any parameters that are known to be unsupported by the LLM.
        # This is important for compatibility with certain models that don't support
        # certain call parameters (e.g. GPT-4 doesn't support 'response_format')
        "drop_params": True,
    }
    if include_response_format:
        # LiteLLM supports passing Pydantic models directly for response_format
        kwargs["response_format"] = response_format or _get_default_judge_response_schema()
    return litellm.completion(**kwargs)


def _process_tool_calls(
    tool_calls: list["litellm.ChatCompletionMessageToolCall"],
    trace: Trace | None,
) -> list["litellm.Message"]:
    """
    Process tool calls and return tool response messages.

    Args:
        tool_calls: List of tool calls from the LLM response.
        trace: Optional trace object for context.

    Returns:
        List of litellm Message objects containing tool responses.
    """
    from mlflow.genai.judges.tools.registry import _judge_tool_registry

    tool_response_messages = []
    for tool_call in tool_calls:
        try:
            mlflow_tool_call = _create_mlflow_tool_call_from_litellm(litellm_tool_call=tool_call)
            result = _judge_tool_registry.invoke(tool_call=mlflow_tool_call, trace=trace)
        except Exception as e:
            tool_response_messages.append(
                _create_litellm_tool_response_message(
                    tool_call_id=tool_call.id,
                    tool_name=tool_call.function.name,
                    content=f"Error: {e!s}",
                )
            )
        else:
            if is_dataclass(result):
                result = asdict(result)
            result_json = json.dumps(result, default=str) if not isinstance(result, str) else result
            tool_response_messages.append(
                _create_litellm_tool_response_message(
                    tool_call_id=tool_call.id,
                    tool_name=tool_call.function.name,
                    content=result_json,
                )
            )
    return tool_response_messages


@_suppress_litellm_nonfatal_errors
def _invoke_litellm_and_handle_tools(
    provider: str,
    model_name: str,
    messages: list["ChatMessage"],
    trace: Trace | None,
    num_retries: int,
    response_format: type[pydantic.BaseModel] | None = None,
) -> str:
    """
    Invoke litellm with retry support and handle tool calling loop.

    Args:
        provider: The provider name (e.g., 'openai', 'anthropic').
        model_name: The model name.
        messages: List of ChatMessage objects.
        trace: Optional trace object for context with tool calling support.
        num_retries: Number of retries with exponential backoff on transient failures.
        response_format: Optional Pydantic model class for structured output format.
                       Used by get_chat_completions_with_structured_output for
                       schema-based extraction.

    Returns:
        The model's response content.

    Raises:
        MlflowException: If the request fails after all retries.
    """
    import litellm

    from mlflow.genai.judges.tools import list_judge_tools

    messages = [litellm.Message(role=msg.role, content=msg.content) for msg in messages]

    litellm_model_uri = f"{provider}/{model_name}"
    tools = []

    if trace is not None:
        judge_tools = list_judge_tools()
        tools = [tool.get_definition().to_dict() for tool in judge_tools]

    def _prune_messages_for_context_window():
        try:
            max_context_length = litellm.get_max_tokens(litellm_model_uri)
        except Exception:
            max_context_length = None

        return _prune_messages_exceeding_context_window_length(
            messages=messages,
            model=litellm_model_uri,
            max_tokens=max_context_length or 100000,
        )

    include_response_format = _MODEL_RESPONSE_FORMAT_CAPABILITIES.get(litellm_model_uri, True)

    max_iterations = MLFLOW_JUDGE_MAX_ITERATIONS.get()
    iteration_count = 0

    while True:
        iteration_count += 1
        if iteration_count > max_iterations:
            raise MlflowException(
                f"Completion iteration limit of {max_iterations} exceeded. "
                f"This usually indicates the model is not powerful enough to effectively "
                f"analyze the trace. Consider using a more intelligent/powerful model. "
                f"In rare cases, for very complex traces where a large number of completion "
                f"iterations might be required, you can increase the number of iterations by "
                f"modifying the {MLFLOW_JUDGE_MAX_ITERATIONS.name} environment variable.",
                error_code=REQUEST_LIMIT_EXCEEDED,
            )
        try:
            try:
                response = _invoke_litellm(
                    litellm_model_uri=litellm_model_uri,
                    messages=messages,
                    tools=tools,
                    num_retries=num_retries,
                    response_format=response_format,
                    include_response_format=include_response_format,
                )
            except (litellm.BadRequestError, litellm.UnsupportedParamsError) as e:
                if isinstance(e, litellm.ContextWindowExceededError) or "context length" in str(e):
                    messages = _prune_messages_for_context_window()
                    continue
                # Check whether the request attempted to use structured outputs, rather than
                # checking whether the model supports structured outputs in the capabilities cache,
                # since the capabilities cache may have been updated between the time that
                # include_response_format was set and the request was made
                if include_response_format:
                    # Retry without response_format if the request failed due to unsupported params.
                    # Some models don't support structured outputs (response_format) at all,
                    # and some models don't support both tool calling and structured outputs.
                    _logger.debug(
                        f"Model {litellm_model_uri} may not support structured outputs or combined "
                        f"tool calling + structured outputs. Error: {e}. "
                        f"Falling back to unstructured response."
                    )
                    _MODEL_RESPONSE_FORMAT_CAPABILITIES[litellm_model_uri] = False
                    include_response_format = False
                    continue
                else:
                    raise

            message = response.choices[0].message
            if not message.tool_calls:
                return message.content

            messages.append(message)
            tool_response_messages = _process_tool_calls(tool_calls=message.tool_calls, trace=trace)
            messages.extend(tool_response_messages)

        except MlflowException:
            raise
        except Exception as e:
            raise MlflowException(f"Failed to invoke the judge via litellm: {e}") from e


def _create_mlflow_tool_call_from_litellm(
    litellm_tool_call: "litellm.ChatCompletionMessageToolCall",
) -> "ToolCall":
    """
    Create an MLflow ToolCall from a LiteLLM tool call.

    Args:
        litellm_tool_call: The LiteLLM ChatCompletionMessageToolCall object.

    Returns:
        An MLflow ToolCall object.
    """
    from mlflow.types.llm import ToolCall

    return ToolCall(
        id=litellm_tool_call.id,
        function={
            "name": litellm_tool_call.function.name,
            "arguments": litellm_tool_call.function.arguments,
        },
    )


def _create_litellm_tool_response_message(
    tool_call_id: str, tool_name: str, content: str
) -> "litellm.Message":
    """
    Create a tool response message for LiteLLM.

    Args:
        tool_call_id: The ID of the tool call being responded to.
        tool_name: The name of the tool that was invoked.
        content: The content to include in the response.

    Returns:
        A litellm.Message object representing the tool response message.
    """
    import litellm

    return litellm.Message(
        tool_call_id=tool_call_id,
        role="tool",
        name=tool_name,
        content=content,
    )


def _get_default_judge_response_schema() -> type[pydantic.BaseModel]:
    """
    Get the default Pydantic schema for judge evaluations.

    Returns:
        A Pydantic BaseModel class defining the standard judge output format.
    """
    # Import here to avoid circular imports
    from mlflow.genai.judges.base import Judge

    output_fields = Judge.get_output_fields()

    field_definitions = {}
    for field in output_fields:
        field_definitions[field.name] = (str, pydantic.Field(description=field.description))

    return pydantic.create_model("JudgeEvaluation", **field_definitions)
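

# Illustrative sketch (not part of the original module): a hand-written equivalent of
# the dynamically created schema, assuming the standard judge output fields are
# "result" and "rationale". The class name is hypothetical.
def _example_default_judge_schema() -> type[pydantic.BaseModel]:
    return pydantic.create_model(
        "ExampleJudgeEvaluation",
        result=(str, pydantic.Field(description="The categorical rating")),
        rationale=(str, pydantic.Field(description="The justification for the rating")),
    )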


def _prune_messages_exceeding_context_window_length(
    messages: list["litellm.Message"],
    model: str,
    max_tokens: int,
) -> list["litellm.Message"]:
    """
    Prune messages from history to stay under token limit.

    Args:
        messages: List of LiteLLM message objects.
        model: Model name for token counting.
        max_tokens: Maximum token limit.

    Returns:
        Pruned list of LiteLLM message objects under the token limit
    """
    import litellm

    initial_tokens = litellm.token_counter(model=model, messages=messages)
    if initial_tokens <= max_tokens:
        return messages

    pruned_messages = messages[:]
    # Remove tool call pairs until we're under limit
    while litellm.token_counter(model=model, messages=pruned_messages) > max_tokens:
        # Find first assistant message with tool calls
        assistant_msg = None
        assistant_idx = None
        for i, msg in enumerate(pruned_messages):
            if msg.role == "assistant" and msg.tool_calls:
                assistant_msg = msg
                assistant_idx = i
                break
        if assistant_msg is None:
            break  # No more tool calls to remove
        pruned_messages.pop(assistant_idx)
        # Remove corresponding tool response messages
        tool_call_ids = {
            tc.id if hasattr(tc, "id") else tc["id"] for tc in assistant_msg.tool_calls
        }
        pruned_messages = [
            msg
            for msg in pruned_messages
            if not (msg.role == "tool" and msg.tool_call_id in tool_call_ids)
        ]

    final_tokens = litellm.token_counter(model=model, messages=pruned_messages)
    _logger.info(f"Pruned message history from {initial_tokens} to {final_tokens} tokens")
    return pruned_messages


def _get_litellm_retry_policy(num_retries: int) -> "litellm.RetryPolicy":
    """
    Get a LiteLLM retry policy for retrying requests when transient API errors occur.

    Args:
        num_retries: The number of times to retry a request if it fails transiently due to
                     network error, rate limiting, etc. Requests are retried with exponential
                     backoff.

    Returns:
        A LiteLLM RetryPolicy instance.
    """
    from litellm import RetryPolicy

    return RetryPolicy(
        TimeoutErrorRetries=num_retries,
        RateLimitErrorRetries=num_retries,
        InternalServerErrorRetries=num_retries,
        ContentPolicyViolationErrorRetries=num_retries,
        # We don't retry on errors that are unlikely to be transient
        # (e.g. bad request, invalid auth credentials)
        BadRequestErrorRetries=0,
        AuthenticationErrorRetries=0,
    )


class CategoricalRating(StrEnum):
    """
    A categorical rating for an assessment.

    Example:

    .. code-block:: python

        from mlflow.genai.judges import CategoricalRating
        from mlflow.entities import Feedback

        # Create feedback with categorical rating
        feedback = Feedback(
            name="my_metric", value=CategoricalRating.YES, rationale="The metric is passing."
        )
    """

    YES = "yes"
    NO = "no"
    UNKNOWN = "unknown"

    @classmethod
    def _missing_(cls, value: str):
        value = value.lower()
        for member in cls:
            if member == value:
                return member
        return cls.UNKNOWN
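

# Illustrative sketch (not part of the original module): _missing_ makes the enum
# tolerant of arbitrary strings, normalizing case and mapping unrecognized values to
# UNKNOWN. The helper name is hypothetical.
def _example_categorical_rating_coercion() -> tuple[CategoricalRating, CategoricalRating]:
    # CategoricalRating("Yes") -> CategoricalRating.YES (case-insensitive match)
    # CategoricalRating("maybe") -> CategoricalRating.UNKNOWN (no matching member)
    return CategoricalRating("Yes"), CategoricalRating("maybe")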