Source code for mlflow.genai.scorers.ragas

"""
RAGAS integration for MLflow.

This module provides integration with RAGAS metrics, allowing them to be used
with MLflow's judge interface.

Example usage:

.. code-block:: python

    from mlflow.genai.scorers.ragas import get_scorer

    judge = get_scorer("Faithfulness", model="openai:/gpt-4")
    feedback = judge(
        inputs="What is MLflow?", outputs="MLflow is a platform...", trace=trace
    )
"""

from __future__ import annotations

import inspect
import logging
import re
from typing import Any

from pydantic import PrivateAttr
from ragas.dataset_schema import MultiTurnSample, SingleTurnSample
from ragas.llms import BaseRagasLLM

from mlflow.entities.assessment import Feedback
from mlflow.entities.assessment_source import AssessmentSource, AssessmentSourceType
from mlflow.entities.trace import Trace
from mlflow.exceptions import MlflowException
from mlflow.genai.judges.builtin import _MODEL_API_DOC
from mlflow.genai.judges.utils import CategoricalRating, get_default_model
from mlflow.genai.scorers import FRAMEWORK_METADATA_KEY
from mlflow.genai.scorers.base import Scorer, ScorerKind
from mlflow.genai.scorers.ragas.models import (
    create_default_embeddings,
    create_ragas_model,
)
from mlflow.genai.scorers.ragas.registry import (
    get_metric_class,
    is_agentic_or_multiturn_metric,
    requires_args_from_placeholders,
    requires_embeddings,
    requires_llm_at_score_time,
    requires_llm_in_constructor,
)
from mlflow.genai.scorers.ragas.utils import (
    create_mlflow_error_message_from_ragas_param,
    map_scorer_inputs_to_ragas_sample,
)
from mlflow.genai.utils.trace_utils import _wrap_async_predict_fn
from mlflow.utils.annotations import experimental
from mlflow.utils.docstring_utils import format_docstring

_logger = logging.getLogger(__name__)


@experimental(version="3.8.0")
@format_docstring(_MODEL_API_DOC)
class RagasScorer(Scorer):
    """
    A scorer that wraps a RAGAS metric and exposes it through MLflow's judge interface.

    Args:
        metric_name: Name of the RAGAS metric (e.g., "Faithfulness")
        model: {{ model }}
        metric_kwargs: Additional metric-specific parameters
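
    Example (a minimal sketch; the metric names and judge model URI are illustrative):

    .. code-block:: python

        from mlflow.genai.scorers.ragas import RagasScorer

        # LLM-judged metric, evaluated with an explicit judge model
        faithfulness = RagasScorer(metric_name="Faithfulness", model="openai:/gpt-4")

        # Deterministic metric: no judge model is needed
        exact_match = RagasScorer(metric_name="ExactMatch")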
    """

    _metric: Any = PrivateAttr()
    _is_deterministic: bool = PrivateAttr(default=False)
    _model: str = PrivateAttr()
    _llm: BaseRagasLLM | None = PrivateAttr(default=None)

    def __init__(
        self,
        metric_name: str | None = None,
        model: str | None = None,
        **metric_kwargs,
    ):
        if metric_name is None:
            metric_name = self.metric_name

        self._validate_args(metric_name, model)
        super().__init__(name=metric_name)
        model = model or get_default_model()
        self._model = model
        metric_class = get_metric_class(metric_name)
        ragas_llm = create_ragas_model(model)
        constructor_kwargs = dict(metric_kwargs)

        if requires_llm_in_constructor(metric_name):
            constructor_kwargs["llm"] = ragas_llm

        if requires_embeddings(metric_name):
            if constructor_kwargs.get("embeddings") is None:
                constructor_kwargs["embeddings"] = create_default_embeddings()

        if requires_llm_at_score_time(metric_name):
            self._llm = ragas_llm

        self._metric = metric_class(**constructor_kwargs)

    @property
    def kind(self) -> ScorerKind:
        return ScorerKind.THIRD_PARTY

    def _raise_registration_not_supported(self, method_name: str):
        raise MlflowException.invalid_parameter_value(
            f"'{method_name}()' is not supported for third-party scorers like RAGAS. "
            f"Third-party scorers cannot be registered, started, updated, or stopped. "
            f"Use them directly in mlflow.genai.evaluate() instead."
        )

    def register(self, **kwargs):
        self._raise_registration_not_supported("register")

    def start(self, **kwargs):
        self._raise_registration_not_supported("start")

    def update(self, **kwargs):
        self._raise_registration_not_supported("update")

    def stop(self, **kwargs):
        self._raise_registration_not_supported("stop")

    def align(self, **kwargs):
        raise MlflowException.invalid_parameter_value(
            "'align()' is not supported for third-party scorers like RAGAS. "
            "Alignment is only available for MLflow's built-in judges."
        )

    def __call__(
        self,
        *,
        inputs: dict[str, Any] | None = None,
        outputs: Any = None,
        expectations: dict[str, Any] | None = None,
        trace: Trace | None = None,
        session: list[Trace] | None = None,
    ) -> Feedback:
        """
        Evaluate using the wrapped RAGAS metric.

        Args:
            inputs: The input to evaluate
            outputs: The output to evaluate
            expectations: Expected values and context for evaluation
            trace: MLflow trace for evaluation
            session: List of MLflow traces for multi-turn/agentic evaluation

        Returns:
            Feedback object with score, rationale, and metadata
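
        Example (a minimal sketch mirroring the module-level example; the strings are
        illustrative and ``trace`` is assumed to be an existing MLflow trace):

        .. code-block:: python

            scorer = RagasScorer(metric_name="Faithfulness", model="openai:/gpt-4")
            feedback = scorer(
                inputs="What is MLflow?",
                outputs="MLflow is a platform...",
                trace=trace,
            )
            print(feedback.value, feedback.rationale)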
        """
        is_deterministic = not (
            requires_llm_in_constructor(self.name) or requires_llm_at_score_time(self.name)
        )
        if is_deterministic:
            assessment_source = AssessmentSource(
                source_type=AssessmentSourceType.CODE,
                source_id=self.name,
            )
        else:
            assessment_source = AssessmentSource(
                source_type=AssessmentSourceType.LLM_JUDGE,
                source_id=self._model,
            )

        try:
            sample = map_scorer_inputs_to_ragas_sample(
                inputs=inputs,
                outputs=outputs,
                expectations=expectations,
                trace=trace,
                session=session,
                is_agentic_or_multiturn=is_agentic_or_multiturn_metric(self.name),
            )

            result = self._evaluate(sample)
            raw_value = getattr(result, "value", result)
            reason = getattr(result, "reason", None)

            try:
                score = float(raw_value)
            except (TypeError, ValueError):
                score = None

            # Some RAGAS metrics expose a threshold used to map the numeric score to binary feedback
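            # (e.g., a score of 0.85 with threshold 0.7 yields CategoricalRating.YES)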
            threshold = getattr(self._metric, "threshold", None)
            metadata = {FRAMEWORK_METADATA_KEY: "ragas"}

            if score is not None and threshold is not None:
                metadata["threshold"] = threshold
                metadata["score"] = score
                value = CategoricalRating.YES if score >= threshold else CategoricalRating.NO
            else:
                value = score if score is not None else raw_value

            return Feedback(
                name=self.name,
                value=value,
                rationale=reason,
                source=assessment_source,
                trace_id=None,
                metadata=metadata,
            )
        except (KeyError, IndexError, ValueError) as e:
            # RAGAS raises KeyError/IndexError/ValueError when required parameters are missing
            error_msg = str(e).strip("'\"")
            mlflow_error_message = create_mlflow_error_message_from_ragas_param(
                error_msg, self.name
            )
            _logger.error(
                f"Missing required parameter for RAGAS metric {self.name}: {mlflow_error_message}"
            )
            mlflow_error = MlflowException.invalid_parameter_value(mlflow_error_message)

            return Feedback(
                name=self.name,
                error=mlflow_error,
                source=assessment_source,
            )
        except Exception as e:
            _logger.error(f"Error evaluating RAGAS metric {self.name}: {e}")
            return Feedback(
                name=self.name,
                error=e,
                source=assessment_source,
            )

    def _evaluate(self, sample: SingleTurnSample | MultiTurnSample):
        if hasattr(self._metric, "single_turn_score"):
            return self._metric.single_turn_score(sample)
        elif hasattr(self._metric, "ascore"):
            kwargs = {}

            if requires_llm_at_score_time(self.name):
                kwargs["llm"] = self._llm

            if requires_args_from_placeholders(self.name):
                kwargs.update(self._extract_prompt_params_from_sample(sample))

            # Inspect the signature, since each metric's ascore() takes a different set of parameters
            sig = inspect.signature(self._metric.ascore)
            for param_name in sig.parameters:
                if param_name == "self":
                    continue

                if hasattr(sample, param_name):
                    value = getattr(sample, param_name)
                    kwargs[param_name] = value

            sync_score = _wrap_async_predict_fn(self._metric.ascore)
            return sync_score(**kwargs)
        else:
            raise MlflowException(f"RAGAS metric {self.name} is not currently supported")

    def _extract_prompt_params_from_sample(
        self, sample: SingleTurnSample | MultiTurnSample
    ) -> dict[str, Any]:
        """
        Extract parameters from the metric's prompt template and get values from sample.

        For metrics like DiscreteMetric where the prompt contains placeholders like
        {response}, {user_input}, etc., this extracts those placeholder names and fetches
        the corresponding values from the sample.
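
        For example, a prompt containing ``{response}`` and ``{user_input}`` yields
        ``{"response": sample.response, "user_input": sample.user_input}``, assuming both
        attributes exist on the sample and are not None.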
        """
        kwargs = {}
        prompt = getattr(self._metric, "prompt", None)
        if prompt is None:
            return kwargs

        prompt_str = str(prompt)
        placeholders = re.findall(r"\{(\w+)\}", prompt_str)

        for param_name in placeholders:
            if hasattr(sample, param_name):
                value = getattr(sample, param_name)
                if value is not None:
                    kwargs[param_name] = value

        return kwargs

    def _validate_args(self, metric_name: str | None, model: str | None):
        metric_name = metric_name or self.metric_name
        if not requires_llm_in_constructor(metric_name) and model is not None:
            raise MlflowException.invalid_parameter_value(
                f"{metric_name} got an unexpected keyword argument 'model'"
            )


@experimental(version="3.8.0")
@format_docstring(_MODEL_API_DOC)
def get_scorer(
    metric_name: str,
    model: str | None = None,
    **metric_kwargs,
) -> RagasScorer:
    """
    Get a RAGAS metric as an MLflow judge.

    Args:
        metric_name: Name of the RAGAS metric (e.g., "Faithfulness")
        model: {{ model }}
        metric_kwargs: Additional metric-specific parameters (e.g., threshold)

    Returns:
        RagasScorer instance that can be called with MLflow's judge interface

    Examples:

    .. code-block:: python

        # LLM-based metric
        judge = get_scorer("Faithfulness", model="openai:/gpt-4")
        feedback = judge(inputs="What is MLflow?", outputs="MLflow is a platform...")

        # Using trace with retrieval context
        judge = get_scorer("ContextPrecision", model="openai:/gpt-4")
        feedback = judge(trace=trace)

        # Deterministic metric (no LLM needed)
        judge = get_scorer("ExactMatch")
        feedback = judge(outputs="Paris", expectations={"expected_output": "Paris"})
    """
    return RagasScorer(
        metric_name=metric_name,
        model=model,
        **metric_kwargs,
    )


from mlflow.genai.scorers.ragas.scorers import (
    AgentGoalAccuracyWithoutReference,
    AgentGoalAccuracyWithReference,
    AnswerAccuracy,
    AnswerRelevancy,
    AspectCritic,
    BleuScore,
    ChrfScore,
    ContextEntityRecall,
    ContextPrecision,
    ContextRecall,
    ContextRelevance,
    DiscreteMetric,
    ExactMatch,
    FactualCorrectness,
    Faithfulness,
    InstanceSpecificRubrics,
    NoiseSensitivity,
    NonLLMContextPrecisionWithReference,
    NonLLMContextRecall,
    NonLLMStringSimilarity,
    ResponseGroundedness,
    RougeScore,
    RubricsScore,
    SemanticSimilarity,
    StringPresence,
    SummarizationScore,
    ToolCallAccuracy,
    ToolCallF1,
    TopicAdherence,
)

__all__ = [
    # Core classes
    "RagasScorer",
    "get_scorer",
    # RAG metrics
    "ContextPrecision",
    "NonLLMContextPrecisionWithReference",
    "ContextRecall",
    "NonLLMContextRecall",
    "ContextEntityRecall",
    "NoiseSensitivity",
    "Faithfulness",
    "AnswerRelevancy",
    "SemanticSimilarity",
    # NVIDIA metrics
    "AnswerAccuracy",
    "ContextRelevance",
    "ResponseGroundedness",
    # Comparison metrics
    "FactualCorrectness",
    "NonLLMStringSimilarity",
    "BleuScore",
    "ChrfScore",
    "RougeScore",
    "StringPresence",
    "ExactMatch",
    # General purpose metrics
    "AspectCritic",
    "DiscreteMetric",
    "RubricsScore",
    "InstanceSpecificRubrics",
    # Agentic metrics
    "TopicAdherence",
    "ToolCallAccuracy",
    "ToolCallF1",
    "AgentGoalAccuracyWithReference",
    "AgentGoalAccuracyWithoutReference",
    # Other tasks
    "SummarizationScore",
]
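

# Usage sketch (kept as a comment, not executed on import): RAGAS scorers are intended to be
# passed to mlflow.genai.evaluate(), as noted in the registration error message above. The
# dataset shape below is an illustrative assumption, not a prescribed format.
#
#     import mlflow
#     from mlflow.genai.scorers.ragas import get_scorer
#
#     results = mlflow.genai.evaluate(
#         data=[{"inputs": {"question": "What is MLflow?"}, "outputs": "MLflow is ..."}],
#         scorers=[get_scorer("Faithfulness", model="openai:/gpt-4")],
#     )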