Source code for mlflow.genai.agent_tester

from __future__ import annotations

import inspect
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable

import pydantic

import mlflow
from mlflow.genai.judges.utils.invocation_utils import get_chat_completions_with_structured_output
from mlflow.utils.annotations import experimental

if TYPE_CHECKING:
    from mlflow.entities.trace import Trace
    from mlflow.genai.discovery.entities import DiscoverIssuesResult

# Module-level logger used for the debug/info/warning messages emitted in this module.
_logger = logging.getLogger(__name__)


# Structured-output schema for the LLM calls that describe the agent under test
# (``_describe_agent_from_response`` and ``_describe_agent_from_traces``).
# Documented with comments rather than a class docstring so the schema handed to
# the model is left exactly as declared by the fields below.
class _AgentDescription(pydantic.BaseModel):
    description: str = pydantic.Field(
        description="What the agent does — a concise summary of its purpose"
    )
    capabilities: list[str] = pydantic.Field(
        description="Tools, skills, or knowledge areas the agent has"
    )
    limitations: list[str] = pydantic.Field(
        description="Known constraints, boundaries, or things the agent cannot do"
    )

    def __str__(self) -> str:
        """
        Render the description as text: a summary line, then bulleted
        ``Capabilities`` and ``Limitations`` sections, separated by blank lines.
        """

        def bulleted(items: list[str]) -> str:
            return "\n".join(f"- {item}" for item in items)

        return "\n\n".join(
            (
                f"Agent description: {self.description}",
                f"Capabilities:\n{bulleted(self.capabilities)}",
                f"Limitations:\n{bulleted(self.limitations)}",
            )
        )


# One simulated-user scenario produced by the test-generation LLM call.
# ``_generate_test_cases`` converts each instance to a dict with ``model_dump()``;
# those dicts are what ``test_agent`` passes to the ``ConversationSimulator``.
class _TestCase(pydantic.BaseModel):
    goal: str = pydantic.Field(description="What the simulated user is trying to accomplish")
    persona: str = pydantic.Field(description="A short description of who the simulated user is")
    simulation_guidelines: list[str] = pydantic.Field(
        description="Instructions for how the simulated user should behave"
    )


# Top-level structured-output schema for ``_generate_test_cases``: the model
# returns a single object holding the list of generated test cases.
class _TestCaseList(pydantic.BaseModel):
    test_cases: list[_TestCase] = pydantic.Field(description="List of test cases to simulate")


@experimental(version="3.13.0")
@dataclass
class AgentTestResult:
    """
    Outcome of a :func:`test_agent` run.

    Attributes:
        test_cases: The generated test cases (as dicts) that were simulated.
        agent_description: Natural-language description of the agent
            produced by Step 1.
        simulation_traces: Traces from the conversation simulator, one list
            per test case.
        issues_result: Full result from the underlying issue detection call.
    """

    test_cases: list[dict[str, Any]]
    agent_description: str
    simulation_traces: list[list[Trace]]
    issues_result: DiscoverIssuesResult

    def __str__(self) -> str:
        """Return the agent description, an issue count, and one line per issue."""
        found = self.issues_result.issues
        issue_lines = [f"  [{i.severity}] {i.name}: {i.description}" for i in found]
        return "\n".join(
            [self.agent_description, "", f"Issues found: {len(found)}", *issue_lines]
        )


# Number of test cases requested from the LLM when ``num_test_cases`` is ``None``.
_DEFAULT_NUM_TEST_CASES = 7

# System prompt for ``_describe_agent_from_response``: structures the agent's reply
# to the self-description request into an ``_AgentDescription``.
_DESCRIBE_AGENT_SYSTEM_PROMPT = """\
You are an expert at analysing AI agents. Given the agent's own response to \
"describe yourself", extract a structured description."""

# System prompt for ``_describe_agent_from_traces``: the accompanying user message
# carries conversations, execution paths, and tool names gathered from traces.
_DESCRIBE_AGENT_FROM_TRACES_SYSTEM_PROMPT = """\
You are an expert at analysing AI agents. Given conversation traces from an \
AI agent, extract a structured description of what the agent does, its \
capabilities, and its limitations."""

# Guidance inserted into the test-generation prompt when the caller's ``guidance``
# is ``None`` or empty.
_DEFAULT_TESTING_GUIDANCE = (
    "Cover a broad mix of the agent's stated capabilities. All test cases should "
    "be realistic. Some should be challenging: ambiguous requests, multi-step "
    "tasks, or requests near the agent's stated limitations."
)

# System prompt for ``_generate_test_cases``. It is rendered with ``str.format``
# (``{guidance}`` placeholder), which is why the literal braces of the JSON
# example are escaped as ``{{`` / ``}}``.
_GENERATE_TEST_CASES_SYSTEM_PROMPT = """\
You are a QA engineer for AI agents. Given a description of an agent, \
generate diverse test cases that exercise different capabilities.

Each test case needs a goal (what the user wants), a persona (who they are), \
and simulation_guidelines (a short list of behavioral instructions for the \
simulated user).

{guidance}

Example output for a weather assistant:

```json
{{
  "test_cases": [
    {{
      "goal": "Get a 7-day weather forecast for Seattle",
      "persona": "A traveler packing for a trip",
      "simulation_guidelines": ["Ask one follow-up about what to wear"]
    }},
    {{
      "goal": "Compare today's weather in Tokyo and London",
      "persona": "Someone scheduling an international call",
      "simulation_guidelines": ["Keep the conversation to 2-3 turns"]
    }}
  ]
}}
```"""


def _get_agent_response_text(predict_fn: Callable[..., Any]) -> str | None:
    """
    Call *predict_fn* with a self-description prompt and return the
    assistant's response as a plain string.
    """

    prompt = [
        {
            "role": "user",
            "content": (
                "What can you do? Describe your capabilities, tools, and limitations in detail."
            ),
        }
    ]

    params = inspect.signature(predict_fn).parameters
    kwarg = "messages" if "messages" in params else "input"

    try:
        result = predict_fn(**{kwarg: prompt})
    except Exception:
        _logger.debug("predict_fn raised when asked to self-describe", exc_info=True)
        return None

    if isinstance(result, str):
        return result

    # Try to extract text the same way the simulator does
    from mlflow.genai.utils.trace_utils import parse_outputs_to_str

    text = parse_outputs_to_str(result)
    if text and text.strip():
        return text

    # Last resort: check the latest trace
    try:
        if trace_id := mlflow.get_last_active_trace_id(thread_local=True):
            from mlflow.genai.utils.trace_utils import extract_outputs_from_trace

            trace = mlflow.get_trace(trace_id)
            if outputs := extract_outputs_from_trace(trace):
                text = parse_outputs_to_str(outputs)
                if text and text.strip():
                    return text
    except Exception:
        _logger.debug("Failed to extract text from last active trace", exc_info=True)

    return None


def _describe_agent_from_response(
    response_text: str,
    model: str,
) -> _AgentDescription:
    """
    Convert the agent's free-text self-description into an ``_AgentDescription``.

    Args:
        response_text: The agent's reply to the self-description prompt.
        model: Model URI used for the structured-output completion.

    Returns:
        The structured description parsed by *model*.
    """
    from mlflow.types.llm import ChatMessage

    system_message = ChatMessage(role="system", content=_DESCRIBE_AGENT_SYSTEM_PROMPT)
    user_message = ChatMessage(
        role="user",
        content=f"Agent's self-description:\n\n{response_text}",
    )

    return get_chat_completions_with_structured_output(
        model_uri=model,
        messages=[system_message, user_message],
        output_schema=_AgentDescription,
    )


def _describe_agent_from_traces(
    traces: list[Trace],
    model: str,
) -> _AgentDescription:
    """
    Infer an ``_AgentDescription`` from existing traces of the agent.

    The context sent to *model* contains the following parts, in order:

    * for each of the first five sessions (as grouped by
      ``group_traces_by_session``), its resolved conversation and its
      execution paths, unless the paths are ``"(no routing)"``;
    * the tool names from the first of the first ten traces that exposes any.

    If no context could be gathered, the literal ``"(no traces)"`` is sent.

    Args:
        traces: Existing traces of the agent.
        model: Model URI used for tool extraction and the structured-output call.

    Returns:
        The structured description parsed by *model*.
    """
    from mlflow.genai.discovery.extraction import extract_execution_paths_for_session
    from mlflow.genai.discovery.utils import group_traces_by_session
    from mlflow.genai.utils.trace_utils import (
        extract_available_tools_from_trace,
        resolve_conversation_from_session,
    )
    from mlflow.types.llm import ChatMessage

    # Caps that keep the prompt size manageable.
    max_sessions = 5
    max_traces_scanned_for_tools = 10

    parts: list[str] = []

    sampled_sessions = list(group_traces_by_session(traces).items())[:max_sessions]
    for sid, grouped in sampled_sessions:
        conversation = resolve_conversation_from_session(grouped)
        if conversation:
            rendered = "\n".join(f"  {msg['role']}: {msg['content']}" for msg in conversation)
            parts.append(f"Conversation ({sid}):\n{rendered}")

        exec_paths = extract_execution_paths_for_session(grouped)
        if exec_paths and exec_paths != "(no routing)":
            parts.append(f"Execution paths: {exec_paths}")

    for candidate in traces[:max_traces_scanned_for_tools]:
        tools = extract_available_tools_from_trace(candidate, model=model)
        if not tools:
            continue
        names = [tool.function.name for tool in tools if tool.function]
        if names:
            parts.append(f"Available tools: {', '.join(names)}")
            break

    user_content = "\n\n".join(parts) if parts else "(no traces)"
    return get_chat_completions_with_structured_output(
        model_uri=model,
        messages=[
            ChatMessage(role="system", content=_DESCRIBE_AGENT_FROM_TRACES_SYSTEM_PROMPT),
            ChatMessage(role="user", content=user_content),
        ],
        output_schema=_AgentDescription,
    )


def _generate_test_cases(
    agent_desc: _AgentDescription,
    model: str,
    num_test_cases: int | None = None,
    guidance: str | None = None,
) -> list[dict[str, Any]]:

    from mlflow.types.llm import ChatMessage

    if num_test_cases is not None and num_test_cases < 1:
        raise ValueError(f"num_test_cases must be >= 1, got {num_test_cases}")
    guidance = guidance or _DEFAULT_TESTING_GUIDANCE
    count = _DEFAULT_NUM_TEST_CASES if num_test_cases is None else num_test_cases

    system_prompt = _GENERATE_TEST_CASES_SYSTEM_PROMPT.format(
        guidance=guidance,
    )
    user_content = str(agent_desc) + f"\n\nGenerate {count} diverse test cases."
    messages = [
        ChatMessage(role="system", content=system_prompt),
        ChatMessage(role="user", content=user_content),
    ]
    result = get_chat_completions_with_structured_output(
        model_uri=model,
        messages=messages,
        output_schema=_TestCaseList,
    )

    return [tc.model_dump() for tc in result.test_cases]


def _resolve_agent_description(
    predict_fn: Callable[..., Any],
    experiment_id: str | None,
    traces: list[Trace] | None,
    model: str,
) -> _AgentDescription:
    """
    Produce a structured description of the agent, falling back step by step.

    Sources are tried in order:

    1. The agent's own answer to a self-description prompt.
    2. If that yields no description or no capabilities: *traces* when
       non-empty, otherwise traces loaded via ``_load_traces(experiment_id)``.
    3. A generic placeholder description.

    Failures of the description LLM calls and of trace loading are logged as
    warnings, and the next source is tried instead.

    Args:
        predict_fn: The agent callable under test.
        experiment_id: Optional experiment to load traces from.
        traces: Optional pre-fetched traces; used instead of loading when non-empty.
        model: Model URI used for the description LLM calls.

    Returns:
        An ``_AgentDescription`` with at least one capability.
    """
    agent_desc: _AgentDescription | None = None

    if response_text := _get_agent_response_text(predict_fn):
        try:
            agent_desc = _describe_agent_from_response(response_text, model)
        except Exception:
            _logger.warning("Failed to describe agent from self-description", exc_info=True)

    if not agent_desc or not agent_desc.capabilities:
        existing_traces = traces
        if not existing_traces:
            # Loading traces is only a fallback. A missing experiment or a
            # tracking-server error must not abort the whole run.
            try:
                existing_traces = _load_traces(experiment_id)
            except Exception:
                _logger.warning(
                    "Failed to load traces from experiment %s", experiment_id, exc_info=True
                )
        if existing_traces:
            try:
                agent_desc = _describe_agent_from_traces(existing_traces, model)
            except Exception:
                _logger.warning("Failed to describe agent from traces", exc_info=True)

    if not agent_desc or not agent_desc.capabilities:
        return _AgentDescription(
            description="A conversational AI agent",
            capabilities=["general conversation"],
            limitations=["unknown"],
        )
    return agent_desc


def _load_traces(
    experiment_id: str | None,
) -> list[Trace] | None:
    if experiment_id is None:
        return None

    return mlflow.search_traces(
        locations=[experiment_id],
        max_results=50,
        return_type="list",
    )


@experimental(version="3.13.0")
def test_agent(
    predict_fn: Callable[..., Any],
    *,
    experiment_id: str | None = None,
    traces: list[Trace] | None = None,
    model: str | None = None,
    max_turns: int = 10,
    max_issues: int = 20,
    num_test_cases: int | None = None,
    guidance: str | None = None,
) -> AgentTestResult:
    """
    Automatically stress-test a conversational AI agent and discover issues.

    Runs a multi-step pipeline:

    1. **Describe** — asks the agent to describe itself (falls back to
       analysing existing traces when available).
    2. **Generate test cases** — uses an LLM to create diverse, targeted
       test scenarios from the agent description.
    3. **Simulate conversations** — runs each test case through the
       :class:`~mlflow.genai.simulators.ConversationSimulator`.
    4. **Detect issues** — analyses simulation traces with ``discover_issues``.

    Args:
        predict_fn: Agent function compatible with
            :class:`~mlflow.genai.simulators.ConversationSimulator`. Must accept
            either ``input`` or ``messages`` for conversation history.
        experiment_id: Optional experiment containing existing traces to help
            describe the agent. Ignored when ``traces`` is provided.
        traces: Optional list of existing traces to help describe the agent.
        model: LLM used for analysis, test generation, and simulation.
            Defaults to the configured default simulation model when ``None``.
        max_turns: Maximum conversation turns per test case.
        max_issues: Maximum number of issues to report.
        num_test_cases: Number of test cases to generate. Defaults to ``7``
            when ``None``.
        guidance: Optional natural-language guidance for what kinds of queries
            to test. For example, ``"Focus on multi-step financial workflows"``.
            When ``None``, uses a default that covers a broad, realistic mix
            of the agent's capabilities.

    Returns:
        An :class:`AgentTestResult` containing discovered issues, generated
        test cases, the agent description, simulation traces, and the full
        :class:`~mlflow.genai.discovery.entities.DiscoverIssuesResult`.

    Raises:
        ValueError: If ``num_test_cases`` is given and is less than 1. This is
            checked before the agent or any LLM is called.

    Example:

        .. code-block:: python

            from openai import OpenAI
            import mlflow

            client = OpenAI()


            @mlflow.trace
            def agent(input: list[dict], **kwargs) -> dict:
                mlflow.update_current_trace(session_id=kwargs.get("mlflow_session_id"))
                response = client.chat.completions.create(
                    model="gpt-4o-mini",
                    messages=[{"role": "system", "content": "You are a helpful assistant."}]
                    + input,
                )
                content = response.choices[0].message.content
                return {"choices": [{"message": {"role": "assistant", "content": content}}]}


            result = mlflow.genai.test_agent(agent, model="openai:/gpt-4o-mini")
            print(result)
    """
    from mlflow.genai.discovery.pipeline import discover_issues
    from mlflow.genai.simulators import ConversationSimulator
    from mlflow.genai.simulators.simulator import _validate_simulator_predict_fn_signature
    from mlflow.genai.simulators.utils import get_default_simulation_model

    _validate_simulator_predict_fn_signature(predict_fn)

    # Fail fast on an invalid count, before the agent is invoked and before any
    # LLM call is made in Step 1. Same error as raised by test-case generation.
    if num_test_cases is not None and num_test_cases < 1:
        raise ValueError(f"num_test_cases must be >= 1, got {num_test_cases}")

    if not model:
        model = get_default_simulation_model()
        _logger.info("Using default model: %s", model)

    # ------------------------------------------------------------------
    # Step 1: Describe the agent
    # ------------------------------------------------------------------
    _logger.info("Step 1/4: Describing the agent")
    agent_desc = _resolve_agent_description(predict_fn, experiment_id, traces, model)
    _logger.info("%s", agent_desc)

    # ------------------------------------------------------------------
    # Step 2: Generate test cases
    # ------------------------------------------------------------------
    _logger.info("Step 2/4: Generating test cases")
    test_cases = _generate_test_cases(agent_desc, model, num_test_cases, guidance)
    _logger.info("Generated %d test cases", len(test_cases))

    # ------------------------------------------------------------------
    # Step 3: Simulate conversations
    # ------------------------------------------------------------------
    _logger.info("Step 3/4: Simulating conversations")
    simulator = ConversationSimulator(
        test_cases=test_cases,
        max_turns=max_turns,
        user_model=model,
    )
    simulation_traces = simulator.simulate(predict_fn)

    # ------------------------------------------------------------------
    # Step 4: Detect issues
    # ------------------------------------------------------------------
    _logger.info("Step 4/4: Detecting issues")
    # The simulator returns one list of traces per test case; issue discovery
    # takes a single flat list.
    flat_traces = [t for session in simulation_traces for t in session]
    issues_result = discover_issues(
        traces=flat_traces,
        model=model,
        max_issues=max_issues,
    )

    return AgentTestResult(
        test_cases=test_cases,
        agent_description=str(agent_desc),
        simulation_traces=simulation_traces,
        issues_result=issues_result,
    )