import contextlib
import importlib.metadata
import importlib.util
import inspect
import io
import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Optional

from packaging.version import Version

from mlflow.entities.model_registry import PromptVersion
from mlflow.exceptions import MlflowException
from mlflow.genai.optimize.optimizers import BasePromptOptimizer
from mlflow.genai.optimize.types import LLMParams, ObjectiveFn, OptimizerConfig, OptimizerOutput
from mlflow.genai.optimize.util import infer_type_from_value
from mlflow.genai.scorers import Scorer
from mlflow.utils.annotations import experimental

if TYPE_CHECKING:
    import dspy
    import pandas as pd

_logger = logging.getLogger(__name__)


@experimental(version="3.3.0")
class DSPyPromptOptimizer(BasePromptOptimizer):
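    """Base class for prompt optimizers built on DSPy.

    Builds a ``dspy.Predict`` program from the prompt and delegates the actual
    search to ``run_optimization``, which concrete subclasses implement
    (for example, with a DSPy teleprompter).
    """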
    def __init__(self, optimizer_config: OptimizerConfig):
        super().__init__(optimizer_config)
        if not importlib.util.find_spec("dspy"):
            raise ImportError("dspy is not installed. Please install it with `pip install dspy`.")
        dspy_version = importlib.metadata.version("dspy")
        if Version(dspy_version) < Version("2.6.0"):
            raise MlflowException(
                f"Current dspy version {dspy_version} is unsupported. "
                "Please upgrade to version >= 2.6.0"
            )

    def _parse_model_name(self, model_name: str) -> str:
        """
        Parse model name from URI format to DSPy format.

        Accepts two formats:

        - URI format: 'openai:/gpt-4o' -> converted to 'openai/gpt-4o'
        - DSPy format: 'openai/gpt-4o' -> returned unchanged

        Raises MlflowException for invalid formats.
        """
        from mlflow.metrics.genai.model_utils import _parse_model_uri
        if not model_name:
            raise MlflowException.invalid_parameter_value(
                "Model name cannot be empty. Please provide a model name in the format "
                "'<provider>:/<model>' or '<provider>/<model>'."
            )
        try:
            scheme, path = _parse_model_uri(model_name)
            return f"{scheme}/{path}"
        except MlflowException:
            if "/" in model_name and ":" not in model_name:
                parts = model_name.split("/")
                if len(parts) == 2 and parts[0] and parts[1]:
                    return model_name
            raise MlflowException.invalid_parameter_value(
                f"Invalid model name format: '{model_name}'. "
                "Model name must be in one of the following formats:\n"
                "  - '<provider>/<model>' (e.g., 'openai/gpt-4')\n"
                "  - '<provider>:/<model>' (e.g., 'openai:/gpt-4')"
            )

    def optimize(
        self,
        prompt: PromptVersion,
        target_llm_params: LLMParams,
        train_data: "pd.DataFrame",
        scorers: list[Scorer],
        objective: ObjectiveFn | None = None,
        eval_data: Optional["pd.DataFrame"] = None,
    ) -> OptimizerOutput:
        import dspy
        _logger.info(
            f"🎯 Starting prompt optimization for: {prompt.uri}\n"
            f"⏱️ This may take several minutes or longer depending on dataset size...\n"
            f"📊 Training with {len(train_data)} examples."
        )
        input_fields = self._get_input_fields(train_data)
        self._validate_input_fields(input_fields, prompt)
        output_fields = self._get_output_fields(train_data)
        lm = dspy.LM(
            model=self._parse_model_name(target_llm_params.model_name),
            temperature=target_llm_params.temperature,
            api_base=target_llm_params.base_uri,
        )
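        # dspy.LM expects a LiteLLM-style "<provider>/<model>" string, hence
        # the conversion from MLflow's "<provider>:/<model>" URI above.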
        if self.optimizer_config.optimizer_llm:
            teacher_lm = dspy.LM(
                model=self._parse_model_name(self.optimizer_config.optimizer_llm.model_name),
                temperature=self.optimizer_config.optimizer_llm.temperature,
                api_base=self.optimizer_config.optimizer_llm.base_uri,
            )
        else:
            teacher_lm = lm
        if self.optimizer_config.extract_instructions:
            instructions = self._extract_instructions(prompt.template, teacher_lm)
        else:
            instructions = prompt.template
        signature = dspy.make_signature(
            {key: (type_, dspy.InputField()) for key, type_ in input_fields.items()}
            | {key: (type_, dspy.OutputField()) for key, type_ in output_fields.items()},
            instructions,
        )
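        # For example, with input_fields={"question": str} and
        # output_fields={"answer": str}, the signature above is roughly
        # equivalent to dspy.Signature("question -> answer", instructions).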
        # Define main student program
        program = dspy.Predict(signature)
        adapter = dspy.JSONAdapter()
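        # JSONAdapter (above) asks the LM for structured JSON output so that
        # typed output fields can be parsed back reliably.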
        train_data = self._convert_to_dspy_dataset(train_data)
        eval_data = self._convert_to_dspy_dataset(eval_data) if eval_data is not None else None
        with dspy.context(lm=lm, adapter=adapter):
            return self.run_optimization(
                prompt=prompt,
                program=program,
                metric=self._convert_to_dspy_metric(
                    input_fields, output_fields, scorers, objective
                ),
                train_data=train_data,
                eval_data=eval_data,
            )

    def run_optimization(
        self,
        prompt: PromptVersion,
        program: "dspy.Module",
        metric: Callable[["dspy.Example"], float],
        train_data: list["dspy.Example"],
        eval_data: list["dspy.Example"],
    ) -> OptimizerOutput:
        """
        Run the optimization process for the given prompt and program.
        Parameters
        ----------
        prompt : PromptVersion
            The prompt version to optimize.
        program : dspy.Module
            The DSPy program/module to optimize.
        metric : Callable[[dspy.Example], float]
            A callable that computes a metric score for a given example.
        train_data : list[dspy.Example]
            List of training examples for optimization.
        eval_data : list[dspy.Example]
            List of evaluation examples for validation.
        Returns
        -------
        OptimizerOutput
            The result of the optimization, including the optimized prompt and metrics.
        Raises
        ------
        NotImplementedError
            This method must be implemented by subclasses.
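
        Examples
        --------
        A minimal subclass sketch (illustrative only; ``BootstrapFewShot`` is
        just one of several DSPy optimizers, and construction of the returned
        ``OptimizerOutput`` is elided)::

            class MyOptimizer(DSPyPromptOptimizer):
                def run_optimization(self, prompt, program, metric, train_data, eval_data):
                    import dspy

                    tuner = dspy.BootstrapFewShot(metric=metric)
                    optimized_program = tuner.compile(program, trainset=train_data)
                    ...  # assemble an OptimizerOutput from optimized_program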
        """
        raise NotImplementedError(
            "Subclasses of DSPyPromptOptimizer must implement `run_optimization`."
        )

    def _get_input_fields(self, train_data: "pd.DataFrame") -> dict[str, type]:
        if "inputs" in train_data.columns:
            sample_input = train_data["inputs"].values[0]
            return {k: infer_type_from_value(v) for k, v in sample_input.items()}
        return {}

    def _get_output_fields(self, train_data: "pd.DataFrame") -> dict[str, type]:
        if "expectations" in train_data.columns:
            sample_output = train_data["expectations"].values[0]
            return {k: infer_type_from_value(v) for k, v in sample_output.items()}
        return {}

    def _convert_to_dspy_dataset(self, data: "pd.DataFrame") -> list["dspy.Example"]:
        import dspy
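        # For example, a row with inputs={"question": "2+2?"} and
        # expectations={"answer": "4"} becomes
        # dspy.Example(question="2+2?", answer="4").with_inputs("question").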
        examples = []
        for _, row in data.iterrows():
            expectations = row["expectations"] if "expectations" in row else {}
            examples.append(
                dspy.Example(**row["inputs"], **expectations).with_inputs(*row["inputs"].keys())
            )
        return examples

    def _convert_to_dspy_metric(
        self,
        input_fields: dict[str, type],
        output_fields: dict[str, type],
        scorers: list[Scorer],
        objective: ObjectiveFn | None = None,
    ) -> Callable[["dspy.Example"], float]:
        def metric(example: "dspy.Example", pred: "dspy.Example", trace=None) -> float:
            scores = {}
            inputs = {key: example.get(key) for key in input_fields.keys()}
            expectations = {key: example.get(key) for key in output_fields.keys()}
            outputs = {key: pred.get(key) for key in output_fields.keys()}
            for scorer in scorers:
                kwargs = {"inputs": inputs, "outputs": outputs, "expectations": expectations}
                signature = inspect.signature(scorer)
                kwargs = {
                    key: value for key, value in kwargs.items() if key in signature.parameters
                }
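                # Only the kwargs the scorer declares are passed; e.g., a
                # scorer defined as `def exact_match(outputs, expectations)`
                # (a hypothetical example) receives only those two arguments.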
                scores[scorer.name] = scorer(**kwargs)
            if objective is not None:
                return objective(scores)
            elif all(isinstance(score, (int, float, bool)) for score in scores.values()):
                # Use total score by default if no objective is provided
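                # (e.g., {"exact_match": True, "relevance": 0.7} sums to 1.7,
                # since bools count as 0/1)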
                return sum(scores.values())
            else:
                non_numerical_scorers = [
                    k for k, v in scores.items() if not isinstance(v, (int, float, bool))
                ]
                raise MlflowException(
                    f"Scorers [{','.join(non_numerical_scorers)}] returned a string, an "
                    "Assessment, or a list of Assessments. Please provide an `objective` "
                    "function to aggregate non-numerical values into a single value for "
                    "optimization."
                )
        return metric

    def _validate_input_fields(self, input_fields: dict[str, type], prompt: PromptVersion) -> None:
        if missing_fields := set(prompt.variables) - set(input_fields.keys()):
            raise MlflowException(
                f"Validation failed. Missing prompt variables in dataset: {missing_fields}. "
                "Please ensure your dataset contains columns for all prompt variables."
            )

    def _extract_instructions(self, template: str | dict[str, Any], lm: "dspy.LM") -> str:
        import dspy
        extractor = dspy.Predict(
            dspy.make_signature(
                {
                    "prompt": (str, dspy.InputField()),
                    "instruction": (str, dspy.OutputField()),
                },
                "Extract the core instructions from the prompt "
                "to use as the system message for the LLM.",
            )
        )
        with dspy.context(lm=lm):
            return extractor(prompt=template).instruction

    @contextlib.contextmanager
    def _maybe_suppress_stdout_stderr(self):
        """Context manager for redirecting stdout/stderr based on verbose setting.
        If verbose is False, redirects output to devnull or StringIO.
        If verbose is True, doesn't redirect output.
        """
        if not self.optimizer_config.verbose:
            try:
                output_sink = open(os.devnull, "w")  # noqa: SIM115
            except (OSError, IOError):
                output_sink = io.StringIO()
            with output_sink:
                with (
                    contextlib.redirect_stdout(output_sink),
                    contextlib.redirect_stderr(output_sink),
                ):
                    yield
        else:
            yield
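

# Illustrative end-to-end usage via the high-level API (a sketch; `exact_match`
# and `train_df` are hypothetical, and argument names follow `optimize()` above
# but may differ by MLflow version):
#
#     import mlflow
#     from mlflow.genai.optimize.types import LLMParams, OptimizerConfig
#
#     result = mlflow.genai.optimize_prompt(
#         target_llm_params=LLMParams(model_name="openai:/gpt-4o-mini"),
#         prompt="prompts:/qa/1",
#         train_data=train_df,  # DataFrame with "inputs"/"expectations" columns
#         scorers=[exact_match],
#         optimizer_config=OptimizerConfig(),
#     )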