from __future__ import annotations
import json
import re
from typing import Any, Optional, Union
from pydantic import BaseModel, ValidationError
from mlflow.entities.model_registry._model_registry_entity import _ModelRegistryEntity
from mlflow.entities.model_registry.model_version_tag import ModelVersionTag
from mlflow.exceptions import MlflowException
from mlflow.prompt.constants import (
IS_PROMPT_TAG_KEY,
PROMPT_TEMPLATE_VARIABLE_PATTERN,
PROMPT_TEXT_DISPLAY_LIMIT,
PROMPT_TEXT_TAG_KEY,
PROMPT_TYPE_CHAT,
PROMPT_TYPE_TAG_KEY,
PROMPT_TYPE_TEXT,
RESPONSE_FORMAT_TAG_KEY,
)
# PromptVersionTag is an alias of ModelVersionTag (same class, prompt-specific name).
PromptVersionTag = ModelVersionTag
def _is_reserved_tag(key: str) -> bool:
    """Return True when ``key`` is one of the tag keys reserved for prompt metadata.

    The reserved keys are the "is prompt" marker, the template text, the prompt
    type, and the response format.
    """
    reserved_keys = {
        IS_PROMPT_TAG_KEY,
        PROMPT_TEXT_TAG_KEY,
        PROMPT_TYPE_TAG_KEY,
        RESPONSE_FORMAT_TAG_KEY,
    }
    return key in reserved_keys
[docs]class PromptVersion(_ModelRegistryEntity):
"""
An entity representing a specific version of a prompt with its template content.
Args:
name: The name of the prompt.
version: The version number of the prompt.
template: The template content of the prompt. Can be either:
- A string containing text with variables enclosed in double curly braces,
e.g. {{variable}}, which will be replaced with actual values by the `format` method.
MLflow uses the same variable naming rules as Jinja2:
https://jinja.palletsprojects.com/en/stable/api/#notes-on-identifiers
- A list of dictionaries representing chat messages, where each message has
'role' and 'content' keys (e.g., [{"role": "user", "content": "Hello {{name}}"}])
response_format: Optional Pydantic class or dictionary defining the expected response
structure. This can be used to specify the schema for structured outputs.
commit_message: The commit message for the prompt version. Optional.
creation_timestamp: Timestamp of the prompt creation. Optional.
tags: A dictionary of tags associated with the **prompt version**.
This is useful for storing version-specific information, such as the author of
the changes. Optional.
aliases: List of aliases for this prompt version. Optional.
last_updated_timestamp: Timestamp of last update. Optional.
user_id: User ID that created this prompt version. Optional.
"""
def __init__(
    self,
    name: str,
    version: int,
    template: Union[str, list[dict[str, Any]]],
    commit_message: Optional[str] = None,
    creation_timestamp: Optional[int] = None,
    tags: Optional[dict[str, str]] = None,
    aliases: Optional[list[str]] = None,
    last_updated_timestamp: Optional[int] = None,
    user_id: Optional[str] = None,
    response_format: Optional[Union[BaseModel, dict[str, Any]]] = None,
):
    """
    Build a prompt version and encode its template/metadata into reserved tags.

    Args:
        name: The name of the prompt.
        version: The version number; stored internally as a string.
        template: Either a text template (str) or a list of chat message dicts.
        commit_message: Optional commit message, stored as the description.
        creation_timestamp: Optional creation timestamp; falls back to 0.
        tags: Optional user tags. The mapping is copied, so the caller's dict
            is never modified by this constructor.
        aliases: Optional list of aliases; defaults to an empty list.
        last_updated_timestamp: Optional timestamp of the last update.
        user_id: Optional ID of the user that created this version.
        response_format: Optional response format; when truthy it is converted
            with ``convert_response_format_to_dict`` and stored as a JSON tag.

    Raises:
        ValueError: If ``template`` is a non-empty list and one of its entries
            fails ``ChatMessage`` validation.
    """
    from mlflow.types.chat import ChatMessage

    super().__init__()

    # Core PromptVersion attributes
    self._name: str = name
    self._version: str = str(version)  # Store as string internally
    self._creation_time: int = creation_timestamp or 0

    # Work on a copy: reserved keys are written into this dict below, and they
    # must not leak into the mapping the caller passed in.
    tags = dict(tags) if tags else {}

    # Determine the prompt type. Only a non-empty list is treated as a chat prompt.
    # NOTE(review): an empty list falls through to the text branch and is stored
    # as the JSON string "[]" -- confirm this is intended.
    if isinstance(template, list) and len(template) > 0:
        try:
            for msg in template:
                ChatMessage.model_validate(msg)
        except ValidationError as e:
            raise ValueError("Template must be a list of dicts with role and content") from e
        self._prompt_type = PROMPT_TYPE_CHAT
    else:
        self._prompt_type = PROMPT_TYPE_TEXT
    tags[PROMPT_TYPE_TAG_KEY] = self._prompt_type

    # Serialize the template once; the same text feeds both the tag and the
    # variable extraction below.
    template_text = template if isinstance(template, str) else json.dumps(template)
    tags[PROMPT_TEXT_TAG_KEY] = template_text
    tags[IS_PROMPT_TAG_KEY] = "true"

    if response_format:
        tags[RESPONSE_FORMAT_TAG_KEY] = json.dumps(
            self.convert_response_format_to_dict(response_format)
        )

    # Store the tags dict (user tags plus the reserved keys set above)
    self._tags: dict[str, str] = tags

    self._variables = set(PROMPT_TEMPLATE_VARIABLE_PATTERN.findall(template_text))
    self._last_updated_timestamp: Optional[int] = last_updated_timestamp
    self._description: Optional[str] = commit_message
    self._user_id: Optional[str] = user_id
    self._aliases: list[str] = aliases or []
def __repr__(self) -> str:
    """Return a short representation with the template truncated for display.

    Text templates are shown as-is; chat templates are shown as their JSON
    dump. Anything longer than ``PROMPT_TEXT_DISPLAY_LIMIT`` characters is cut
    at the limit and suffixed with "...".
    """
    if self.is_text_prompt:
        rendered = self.template
    else:
        rendered = json.dumps(self.template)

    if len(rendered) > PROMPT_TEXT_DISPLAY_LIMIT:
        rendered = rendered[:PROMPT_TEXT_DISPLAY_LIMIT] + "..."

    return f"PromptVersion(name={self.name}, version={self.version}, template={rendered})"
# Core PromptVersion properties
@property
def template(self) -> Union[str, list[dict[str, Any]]]:
    """
    Return the template content of the prompt.

    The template is read back from the reserved template-text tag.

    Returns:
        The stored string for text prompts, or the JSON-decoded value of the
        stored string for chat prompts (a list of message dictionaries with
        'role' and 'content' keys).
    """
    is_text = self.is_text_prompt
    stored = self._tags[PROMPT_TEXT_TAG_KEY]
    return stored if is_text else json.loads(stored)
@property
def is_text_prompt(self) -> bool:
    """
    Tell whether this prompt is a text prompt rather than a chat prompt.

    Returns:
        True when the stored prompt type equals ``PROMPT_TYPE_TEXT``
        (string templates), False otherwise (chat prompts).
    """
    prompt_type = self._prompt_type
    return prompt_type == PROMPT_TYPE_TEXT
@property
def response_format(self) -> Optional[dict[str, Any]]:
    """
    Return the response format specification for the prompt.

    Returns:
        The JSON-decoded value of the response-format tag, or None when that
        tag is not present on this prompt version.
    """
    try:
        encoded = self._tags[RESPONSE_FORMAT_TAG_KEY]
    except KeyError:
        # No response format was stored for this version.
        return None
    return json.loads(encoded)
@property
def variables(self) -> set[str]:
"""
Return a list of variables in the template text.
The value must be enclosed in double curly braces, e.g. {{variable}}.
"""
return self._variables
@property
def commit_message(self) -> Optional[str]:
"""
Return the commit message of the prompt version.
"""
return self.description
@property
def tags(self) -> dict[str, str]:
    """
    Return the version-level tags.

    A new dictionary is built on every access; entries whose key is a
    reserved prompt tag key are left out.
    """
    visible: dict[str, str] = {}
    for tag_key, tag_value in self._tags.items():
        if _is_reserved_tag(tag_key):
            continue
        visible[tag_key] = tag_value
    return visible
@property
def uri(self) -> str:
"""Return the URI of the prompt."""
return f"prompts:/{self.name}/{self.version}"
@property
def name(self) -> str:
"""String. Unique name within Model Registry."""
return self._name
@name.setter
def name(self, new_name: str):
self._name = new_name
@property
def version(self) -> int:
"""Version"""
return int(self._version)
@property
def creation_timestamp(self) -> int:
"""Integer. Prompt version creation timestamp (milliseconds since the Unix epoch)."""
return self._creation_time
@property
def last_updated_timestamp(self) -> Optional[int]:
"""Integer. Timestamp of last update for this prompt version (milliseconds since the Unix
epoch).
"""
return self._last_updated_timestamp
@last_updated_timestamp.setter
def last_updated_timestamp(self, updated_timestamp: int):
self._last_updated_timestamp = updated_timestamp
@property
def description(self) -> Optional[str]:
"""String. Description"""
return self._description
@description.setter
def description(self, description: str):
self._description = description
@property
def user_id(self) -> Optional[str]:
"""String. User ID that created this prompt version."""
return self._user_id
@property
def aliases(self) -> list[str]:
"""List of aliases (string) for the current prompt version."""
return self._aliases
@aliases.setter
def aliases(self, aliases: list[str]):
self._aliases = aliases
# Methods
@classmethod
def _properties(cls) -> list[str]:
    """Return the sorted property names reported by ``_get_properties_helper``."""
    # aggregate with base class properties since cls.__dict__ does not do it automatically
    property_names = list(cls._get_properties_helper())
    property_names.sort()
    return property_names
def _add_tag(self, tag: ModelVersionTag):
self._tags[tag.key] = tag.value