# Optimizing Prompts for LangGraph

This guide shows how to use `mlflow.genai.optimize_prompts()` together with LangGraph to automatically improve your agent's prompts. The API is framework-agnostic, so you can run end-to-end prompt optimization on graphs built with any framework using state-of-the-art techniques. For more information about the API, see Optimize Prompts.
## Prerequisites

```bash
pip install -U langgraph langchain langchain-openai mlflow gepa litellm
```
Set your OpenAI API key:
```bash
export OPENAI_API_KEY="your-api-key"
```
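If you are working in a notebook where shell exports are not available, you can set the key from Python instead (a minimal equivalent sketch; the key value is a placeholder):

```python
import os

# Placeholder value; substitute your real OpenAI API key
os.environ["OPENAI_API_KEY"] = "your-api-key"
```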
Set the tracking server URI and the MLflow experiment:

```python
import mlflow

mlflow.set_tracking_uri("http://localhost:5000")
mlflow.set_experiment("LangGraph Optimization")
```
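If no tracking server is running at that address yet, you can start a local one first. This sketch assumes the default port 5000 used in the URI above:

```bash
mlflow server --host 127.0.0.1 --port 5000
```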
## Basic Example
Below is a complete example that optimizes a customer support agent built with LangGraph. It optimizes both the system and user prompts in a stateful graph workflow, and shows the minimal code changes needed to integrate prompt optimization into an existing LangGraph application.
```python
import mlflow
from mlflow.genai.scorers import Correctness
from mlflow.genai.optimize.optimizers import GepaPromptOptimizer
from langgraph.graph import StateGraph, START, END
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, HumanMessage
from typing_extensions import TypedDict, Annotated
import operator

# Step 1: Register your initial prompts
system_prompt = mlflow.genai.register_prompt(
    name="customer-support-system",
    template="You are a helpful customer support agent for an e-commerce platform. "
    "Assist customers with their questions about orders, returns, and products.",
)

user_prompt = mlflow.genai.register_prompt(
    name="customer-support-query",
    template="Customer inquiry: {{query}}",
)

# Step 2: Define state schema for LangGraph
class AgentState(TypedDict):
    messages: Annotated[list, operator.add]
    query: str
    llm_calls: int

# Step 3: Create a prediction function that uses LangGraph
def predict_fn(query):
    # Load prompts from registry
    system_prompt = mlflow.genai.load_prompt("prompts:/customer-support-system@latest")
    user_prompt = mlflow.genai.load_prompt("prompts:/customer-support-query@latest")

    # Initialize model
    model = ChatOpenAI(model="gpt-4o-mini", temperature=0)

    # Define the LLM node
    def llm_node(state: AgentState):
        formatted_user_msg = user_prompt.format(query=state["query"])
        messages = [
            SystemMessage(content=system_prompt.template),
            HumanMessage(content=formatted_user_msg),
        ]
        response = model.invoke(messages)
        return {
            "messages": [response],
            "llm_calls": state.get("llm_calls", 0) + 1,
        }

    # Build the graph
    graph_builder = StateGraph(AgentState)
    graph_builder.add_node("llm_call", llm_node)
    graph_builder.add_edge(START, "llm_call")
    graph_builder.add_edge("llm_call", END)

    # Compile and run
    agent = graph_builder.compile()
    result = agent.invoke({"query": query, "messages": [], "llm_calls": 0})
    return result["messages"][-1].content

# Step 4: Prepare training data
dataset = [
    {
        "inputs": {"query": "Where is my order #12345?"},
        "expectations": {
            "expected_response": "I'd be happy to help you track your order #12345. "
            "Please check your email for a tracking link, or I can look it up for you if you provide your email address."
        },
    },
    {
        "inputs": {"query": "How do I return a defective product?"},
        "expectations": {
            "expected_response": "I'm sorry to hear your product is defective. You can initiate a return "
            "through your account's order history within 30 days of purchase. We'll send you a prepaid shipping label."
        },
    },
    {
        "inputs": {"query": "Do you have this item in blue?"},
        "expectations": {
            "expected_response": "I'd be happy to check product availability for you. "
            "Could you please provide the product name or SKU so I can verify if it's available in blue?"
        },
    },
    # more data...
]

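# Each record's "inputs" keys are passed to predict_fn as keyword arguments
# (here, "query"), and "expectations" supplies the reference answer that the
# Correctness scorer compares the agent's response against.
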
# Step 5: Optimize the prompts
result = mlflow.genai.optimize_prompts(
    predict_fn=predict_fn,
    train_data=dataset,
    prompt_uris=[system_prompt.uri, user_prompt.uri],
    optimizer=GepaPromptOptimizer(reflection_model="openai:/gpt-4o"),
    scorers=[Correctness(model="openai:/gpt-4o")],
)

# Step 6: Use the optimized prompts
optimized_system_prompt = result.optimized_prompts[0]
optimized_user_prompt = result.optimized_prompts[1]
print(f"Optimized system prompt URI: {optimized_system_prompt.uri}")
print(f"Optimized system template: {optimized_system_prompt.template}")
print(f"Optimized user prompt URI: {optimized_user_prompt.uri}")
print(f"Optimized user template: {optimized_user_prompt.template}")

# Since your graph already uses @latest, it will automatically use the optimized prompts
predict_fn("Can I get a refund for order #67890?")
```