Source code for mlflow.genai.optimize.optimizers.dspy_optimizer

import contextlib
import importlib.metadata
import importlib.util
import inspect
import io
import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Optional

from packaging.version import Version

from mlflow.entities.model_registry import PromptVersion
from mlflow.exceptions import MlflowException
from mlflow.genai.optimize.optimizers import BasePromptOptimizer
from mlflow.genai.optimize.types import LLMParams, ObjectiveFn, OptimizerConfig, OptimizerOutput
from mlflow.genai.optimize.util import infer_type_from_value
from mlflow.genai.scorers import Scorer
from mlflow.utils.annotations import experimental

if TYPE_CHECKING:
    import dspy
    import pandas as pd

_logger = logging.getLogger(__name__)


@experimental(version="3.3.0")
class DSPyPromptOptimizer(BasePromptOptimizer):
    def __init__(self, optimizer_config: OptimizerConfig):
        super().__init__(optimizer_config)
        if not importlib.util.find_spec("dspy"):
            raise ImportError("dspy is not installed. Please install it with `pip install dspy`.")

        dspy_version = importlib.metadata.version("dspy")
        if Version(dspy_version) < Version("2.6.0"):
            raise MlflowException(
                f"Current dspy version {dspy_version} is unsupported. "
                "Please upgrade to version >= 2.6.0"
            )

    def _parse_model_name(self, model_name: str) -> str:
        """
        Parse model name from URI format to DSPy format.

        Accepts two formats:
        - URI format: 'openai:/gpt-4o' -> converted to 'openai/gpt-4o'
        - DSPy format: 'openai/gpt-4o' -> returned unchanged

        Raises MlflowException for invalid formats.
        """
        from mlflow.metrics.genai.model_utils import _parse_model_uri

        if not model_name:
            raise MlflowException.invalid_parameter_value(
                "Model name cannot be empty. Please provide a model name in the format "
                "'<provider>:/<model>' or '<provider>/<model>'."
            )

        try:
            scheme, path = _parse_model_uri(model_name)
            return f"{scheme}/{path}"
        except MlflowException:
            if "/" in model_name and ":" not in model_name:
                parts = model_name.split("/")
                if len(parts) == 2 and parts[0] and parts[1]:
                    return model_name
            raise MlflowException.invalid_parameter_value(
                f"Invalid model name format: '{model_name}'. "
                "Model name must be in one of the following formats:\n"
                " - '<provider>/<model>' (e.g., 'openai/gpt-4')\n"
                " - '<provider>:/<model>' (e.g., 'openai:/gpt-4')"
            )
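    # Illustrative sketch (not part of the module): how _parse_model_name maps the
    # two accepted formats, inferred from the method body above. Assumes an
    # optimizer built with a default OptimizerConfig and that dspy is installed
    # so __init__ succeeds.
    #
    #   optimizer = DSPyPromptOptimizer(OptimizerConfig())
    #   optimizer._parse_model_name("openai:/gpt-4o")  # -> "openai/gpt-4o"
    #   optimizer._parse_model_name("openai/gpt-4o")   # -> "openai/gpt-4o"
    #   optimizer._parse_model_name("gpt-4o")          # raises MlflowException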
    def optimize(
        self,
        prompt: PromptVersion,
        target_llm_params: LLMParams,
        train_data: "pd.DataFrame",
        scorers: list[Scorer],
        objective: ObjectiveFn | None = None,
        eval_data: Optional["pd.DataFrame"] = None,
    ) -> OptimizerOutput:
        import dspy

        _logger.info(
            f"🎯 Starting prompt optimization for: {prompt.uri}\n"
            f"⏱️ This may take several minutes or longer depending on dataset size...\n"
            f"📊 Training with {len(train_data)} examples."
        )

        input_fields = self._get_input_fields(train_data)
        self._validate_input_fields(input_fields, prompt)
        output_fields = self._get_output_fields(train_data)

        lm = dspy.LM(
            model=self._parse_model_name(target_llm_params.model_name),
            temperature=target_llm_params.temperature,
            api_base=target_llm_params.base_uri,
        )

        if self.optimizer_config.optimizer_llm:
            teacher_lm = dspy.LM(
                model=self._parse_model_name(self.optimizer_config.optimizer_llm.model_name),
                temperature=self.optimizer_config.optimizer_llm.temperature,
                api_base=self.optimizer_config.optimizer_llm.base_uri,
            )
        else:
            teacher_lm = lm

        if self.optimizer_config.extract_instructions:
            instructions = self._extract_instructions(prompt.template, teacher_lm)
        else:
            instructions = prompt.template

        signature = dspy.make_signature(
            {key: (type_, dspy.InputField()) for key, type_ in input_fields.items()}
            | {key: (type_, dspy.OutputField()) for key, type_ in output_fields.items()},
            instructions,
        )

        # Define the main student program
        program = dspy.Predict(signature)
        adapter = dspy.JSONAdapter()
        train_data = self._convert_to_dspy_dataset(train_data)
        eval_data = self._convert_to_dspy_dataset(eval_data) if eval_data is not None else None

        with dspy.context(lm=lm, adapter=adapter):
            return self.run_optimization(
                prompt=prompt,
                program=program,
                metric=self._convert_to_dspy_metric(
                    input_fields, output_fields, scorers, objective
                ),
                train_data=train_data,
                eval_data=eval_data,
            )
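    # Illustrative sketch (not part of the module): the shape of the DataFrame
    # optimize() expects. The "inputs" and "expectations" columns hold dicts, and
    # their keys become DSPy input/output fields via _get_input_fields and
    # _get_output_fields below. Column values here are hypothetical.
    #
    #   import pandas as pd
    #   train_data = pd.DataFrame(
    #       {
    #           "inputs": [{"question": "What is MLflow?"}],
    #           "expectations": [{"answer": "An open source MLOps platform."}],
    #       }
    #   )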
    def run_optimization(
        self,
        prompt: PromptVersion,
        program: "dspy.Module",
        metric: Callable[["dspy.Example"], float],
        train_data: list["dspy.Example"],
        eval_data: Optional[list["dspy.Example"]],
    ) -> OptimizerOutput:
        """
        Run the optimization process for the given prompt and program.

        Parameters
        ----------
        prompt : PromptVersion
            The prompt version to optimize.
        program : dspy.Module
            The DSPy program/module to optimize.
        metric : Callable[[dspy.Example], float]
            A callable that computes a metric score for a given example.
        train_data : list[dspy.Example]
            List of training examples for optimization.
        eval_data : list[dspy.Example], optional
            List of evaluation examples for validation; may be None when no
            evaluation dataset is provided.

        Returns
        -------
        OptimizerOutput
            The result of the optimization, including the optimized prompt and metrics.

        Raises
        ------
        NotImplementedError
            This method must be implemented by subclasses.
        """
        raise NotImplementedError(
            "Subclasses of DSPyPromptOptimizer must implement `run_optimization`."
        )
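    # Illustrative sketch (not part of the module): a minimal subclass wiring a
    # DSPy teleprompter into run_optimization. The dspy.MIPROv2 constructor and
    # compile() signature follow public DSPy documentation; treat those details,
    # and the class name, as assumptions rather than the library's guaranteed API.
    #
    #   class MyDSPyOptimizer(DSPyPromptOptimizer):
    #       def run_optimization(self, prompt, program, metric, train_data, eval_data):
    #           import dspy
    #           optimized = dspy.MIPROv2(metric=metric, auto="light").compile(
    #               program, trainset=train_data
    #           )
    #           ...  # build and return an OptimizerOutput from `optimized`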
    def _get_input_fields(self, train_data: "pd.DataFrame") -> dict[str, type]:
        if "inputs" in train_data.columns:
            sample_input = train_data["inputs"].values[0]
            return {k: infer_type_from_value(v) for k, v in sample_input.items()}
        return {}

    def _get_output_fields(self, train_data: "pd.DataFrame") -> dict[str, type]:
        if "expectations" in train_data.columns:
            sample_output = train_data["expectations"].values[0]
            return {k: infer_type_from_value(v) for k, v in sample_output.items()}
        return {}

    def _convert_to_dspy_dataset(self, data: "pd.DataFrame") -> list["dspy.Example"]:
        import dspy

        examples = []
        for _, row in data.iterrows():
            expectations = row["expectations"] if "expectations" in row else {}
            examples.append(
                dspy.Example(**row["inputs"], **expectations).with_inputs(*row["inputs"].keys())
            )
        return examples

    def _convert_to_dspy_metric(
        self,
        input_fields: dict[str, type],
        output_fields: dict[str, type],
        scorers: list[Scorer],
        objective: ObjectiveFn | None = None,
    ) -> Callable[["dspy.Example"], float]:
        def metric(example: "dspy.Example", pred: "dspy.Example", trace=None) -> float:
            scores = {}
            inputs = {key: example.get(key) for key in input_fields.keys()}
            expectations = {key: example.get(key) for key in output_fields.keys()}
            outputs = {key: pred.get(key) for key in output_fields.keys()}
            for scorer in scorers:
                kwargs = {"inputs": inputs, "outputs": outputs, "expectations": expectations}
                signature = inspect.signature(scorer)
                kwargs = {
                    key: value for key, value in kwargs.items() if key in signature.parameters
                }
                scores[scorer.name] = scorer(**kwargs)
            if objective is not None:
                return objective(scores)
            elif all(isinstance(score, (int, float, bool)) for score in scores.values()):
                # Use the total score by default if no objective is provided
                return sum(scores.values())
            else:
                non_numerical_scorers = [
                    k for k, v in scores.items() if not isinstance(v, (int, float, bool))
                ]
                raise MlflowException(
                    f"Scorers [{','.join(non_numerical_scorers)}] returned a string, an "
                    "Assessment, or a list of Assessments. Please provide an `objective` "
                    "function to aggregate non-numerical values into a single value for "
                    "optimization."
                )

        return metric

    def _validate_input_fields(self, input_fields: dict[str, type], prompt: PromptVersion) -> None:
        if missing_fields := set(prompt.variables) - set(input_fields.keys()):
            raise MlflowException(
                f"Validation failed. Missing prompt variables in dataset: {missing_fields}. "
                "Please ensure your dataset contains columns for all prompt variables."
            )

    def _extract_instructions(self, template: str | dict[str, Any], lm: "dspy.LM") -> str:
        import dspy

        extractor = dspy.Predict(
            dspy.make_signature(
                {
                    "prompt": (str, dspy.InputField()),
                    "instruction": (str, dspy.OutputField()),
                },
                "Extract the core instructions from the prompt "
                "to use as the system message for the LLM.",
            )
        )
        with dspy.context(lm=lm):
            return extractor(prompt=template).instruction

    @contextlib.contextmanager
    def _maybe_suppress_stdout_stderr(self):
        """Context manager that redirects stdout/stderr based on the verbose setting.

        If verbose is False, redirects output to devnull or a StringIO buffer.
        If verbose is True, does not redirect output.
        """
        if not self.optimizer_config.verbose:
            try:
                output_sink = open(os.devnull, "w")  # noqa: SIM115
            except (OSError, IOError):
                output_sink = io.StringIO()

            with output_sink:
                with (
                    contextlib.redirect_stdout(output_sink),
                    contextlib.redirect_stderr(output_sink),
                ):
                    yield
        else:
            yield
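
# Illustrative sketch (not part of the module): how a subclass might wrap a noisy
# DSPy teleprompter run with _maybe_suppress_stdout_stderr so that output is only
# shown when OptimizerConfig.verbose is True. The `teleprompter` name is
# hypothetical.
#
#   with self._maybe_suppress_stdout_stderr():
#       optimized_program = teleprompter.compile(program, trainset=train_data)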